diff --git a/.gitattributes b/.gitattributes
index c9cee8764dcb42ccb10f4d362a8bc1230bf18fba..617adfb2762a4a28e53ce2cc8a3275e406028a28 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -10081,3 +10081,480 @@ var/VAR_dd/visualize/attn_score/15/class_980/seed_1/raw_map_map_6.jpg filter=lfs
 var/VAR_dd/visualize/attn_score/15/class_980/seed_1/raw_map_map_7.jpg filter=lfs diff=lfs merge=lfs -text
 var/VAR_dd/visualize/attn_score/15/class_980/seed_1/raw_map_map_8.jpg filter=lfs diff=lfs merge=lfs -text
 var/VAR_dd/visualize/attn_score/15/class_980/seed_1/raw_map_map_9.jpg filter=lfs diff=lfs merge=lfs -text
+MRI_recon/code/Frequency-Diffusion/FSMNet/debug/True_0_0.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/figures/FSMNet.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-0-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-0-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-1-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-1-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-2-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-2-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-0-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-0-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-1-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-1-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-2-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-2-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-0-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-0-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-1-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-1-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-2-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-2-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-0-compare.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-0-middle.png filter=lfs diff=lfs merge=lfs -text
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-1-middle.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-2-middle.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-0-middle.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-2-compare.png filter=lfs 
diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-0-compare.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-1-middle.png filter=lfs 
diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-2-middle.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-1-compare.png filter=lfs 
diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-2-compare.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-0-compare.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-1-compare.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-2-compare.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-0-middle.png 
filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-1-middle.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-0-compare.png filter=lfs 
diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-1-compare.png filter=lfs diff=lfs merge=lfs -text 
+MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-2-middle.png filter=lfs 
diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-0-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-0-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-1-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-1-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-2-compare.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-2-middle.png filter=lfs diff=lfs merge=lfs -text +MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/figures/FSMNet.png filter=lfs diff=lfs merge=lfs -text diff --git a/MRI_recon/code/Frequency-Diffusion/.gitignore b/MRI_recon/code/Frequency-Diffusion/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d1c3ac773ddd4815dcc056d9671582ad12f2295f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/.gitignore @@ -0,0 +1,142 @@ +# Generation results +results/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +log./ +log.txt +.log + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +*.png +*.pth +# Translations +*.mo +*.pot + +# Django stuff: +*.log +log/ +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ +.DS_Store +.idea/ +apex diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/LICENSE b/MRI_recon/code/Frequency-Diffusion/FSMNet/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/README.md b/MRI_recon/code/Frequency-Diffusion/FSMNet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8f9aaadef4dd0210e6f11eb09f082c241e08051e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/README.md @@ -0,0 +1,97 @@ +# FSMNet +FSMNet efficiently explores global dependencies across different modalities. Specifically, the features for each modality are extracted by the Frequency-Spatial Feature Extraction (FSFE) module, featuring a frequency branch and a spatial branch. Benefiting from the global property of the Fourier transform, the frequency branch can efficiently capture global dependency with an image-size receptive field, while the spatial branch can extract local features. To exploit complementary information from the auxiliary modality, we propose a Cross-Modal Selective fusion (CMS-fusion) module that selectively incorporate the frequency and spatial features from the auxiliary modality to enhance the corresponding branch of the target modality. To further integrate the enhanced global features from the frequency branch and the enhanced local features from the spatial branch, we develop a Frequency-Spatial fusion (FS-fusion) module, resulting in a comprehensive feature representation for the target modality. + +
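The sketch below makes the frequency/spatial split described above concrete: a frequency branch that mixes features after a 2-D FFT (so every output location depends on the whole image) and a spatial branch built from ordinary 3x3 convolutions, with a simple 1x1 fusion standing in for the CMS-fusion/FS-fusion modules. This is an illustrative toy block only, not the released FSMNet implementation; the names `ToyFSFEBlock` and `channels` are invented for this example.

```python
# Minimal sketch of a frequency/spatial dual-branch block (NOT the official FSMNet code).
# The frequency branch operates on the 2-D FFT of the feature map, so each output
# position depends on every input position; the spatial branch uses local convolutions.
import torch
import torch.nn as nn


class ToyFSFEBlock(nn.Module):
    """Toy frequency-spatial feature block: FFT-domain mixing plus local convolution."""

    def __init__(self, channels: int):
        super().__init__()
        # Frequency branch: 1x1 convs applied to the real/imaginary parts of the spectrum.
        self.freq_mix = nn.Sequential(
            nn.Conv2d(2 * channels, 2 * channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * channels, 2 * channels, kernel_size=1),
        )
        # Spatial branch: ordinary 3x3 convolutions with a local receptive field.
        self.spatial = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        # Simple fusion of the two branches (stand-in for the CMS-fusion / FS-fusion modules).
        self.fuse = nn.Conv2d(2 * channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # Frequency branch: global mixing via the 2-D FFT of the feature map.
        spec = torch.fft.rfft2(x, norm="ortho")              # (B, C, H, W//2 + 1), complex
        spec = torch.cat([spec.real, spec.imag], dim=1)      # (B, 2C, H, W//2 + 1)
        spec = self.freq_mix(spec)
        real, imag = spec.chunk(2, dim=1)
        freq_feat = torch.fft.irfft2(torch.complex(real, imag), s=(h, w), norm="ortho")
        # Spatial branch: local feature extraction.
        spatial_feat = self.spatial(x)
        # Fuse the global (frequency) and local (spatial) features with a residual connection.
        return x + self.fuse(torch.cat([freq_feat, spatial_feat], dim=1))


if __name__ == "__main__":
    block = ToyFSFEBlock(channels=16)
    y = block(torch.randn(1, 16, 240, 240))  # BraTS slices in this repo are 240x240
    print(y.shape)                           # torch.Size([1, 16, 240, 240])
```

With `norm="ortho"` the FFT/IFFT pair is unitary, so the frequency branch stays on the same scale as the spatial branch and the residual connection remains well behaved.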

+ +## Paper + +Accelerated Multi-Contrast MRI Reconstruction via Frequency and Spatial Mutual Learning
+[Qi Chen](https://scholar.google.com/citations?user=4Q5gs2MAAAAJ&hl=en)<sup>1</sup>, [Xiaohan Xing](https://hathawayxxh.github.io/)<sup>2,*</sup>, [Zhen Chen](https://franciszchen.github.io/)<sup>3</sup>, [Zhiwei Xiong](http://staff.ustc.edu.cn/~zwxiong/)<sup>1</sup>
+<sup>1</sup> University of Science and Technology of China,
+<sup>2</sup> Stanford University,
+<sup>3</sup> Centre for Artificial Intelligence and Robotics (CAIR), HKISI-CAS
+MICCAI, 2024
+[paper](http://arxiv.org/abs/2409.14113) | [code](https://github.com/qic999/FSMNet) | [huggingface](https://huggingface.co/datasets/qicq1c/MRI_Reconstruction) + +## 0. Installation + +```bash +git clone https://github.com/qic999/FSMNet.git +cd FSMNet +``` + +See [installation instructions](documents/INSTALL.md) to create an environment and obtain requirements. + +## 1. Prepare datasets +Download BraTS dataset and fastMRI dataset and save them to the `datapath` directory. +``` +cd $datapath +# download brats dataset +wget https://huggingface.co/datasets/qicq1c/MRI_Reconstruction/resolve/main/BRATS_100patients.zip +unzip BRATS_100patients.zip +# download fastmri dataset +wget https://huggingface.co/datasets/qicq1c/MRI_Reconstruction/resolve/main/singlecoil_train_selected.zip +unzip singlecoil_train_selected.zip +``` + +## 2. Training +##### BraTS dataset, AF=4 +``` +python train_brats.py --root_path /data/qic99/MRI_recon image_100patients_4X/ \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x +``` + +##### BraTS dataset, AF=8 +``` +python train_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x +``` + +##### fastMRI dataset, AF=4 +``` +python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x +``` + +##### fastMRI dataset, AF=8 +``` +python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x +``` + +## 3. 
Testing +##### BraTS dataset, AF=4 +``` +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_4X/ \ + --gpu 3 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x --phase test +``` + +##### BraTS dataset, AF=8 +``` +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \ + --gpu 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x --phase test +``` + +##### fastMRI dataset, AF=4 +``` +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 5 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --phase test +``` + +##### fastMRI dataset, AF=8 +``` +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 6 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --phase test +``` \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/__init__.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/brats.sh b/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/brats.sh new file mode 100644 index 0000000000000000000000000000000000000000..66b4f9009b7d83cc3eed5c91ee095437c119fd80 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/brats.sh @@ -0,0 +1,46 @@ +# BraTS dataset, AF=4 +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion/FSMNet-modify +cd /bask/projects/j/jiaoj-rep-learn/Hao/repo/Frequency-Diffusion/FSMNet-modify + + +gamedrive=/media/cbtil3/74ec35fd-2452-4dcc-8d7d-3ba957e302c9 + +#4T folder: /media/cbtil3/9feaf350-913e-4def-8114-f03573c04364/hao +root_path_4x=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_4X/ +root_path_8x=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_8X/ + +root_path_4x=$gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/ +root_path_8x=$gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/ + +python train_brats.py --root_path $root_path_4x\ + --gpu 0 --batch_size 4 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_BraTS_4x --use_time_model --use_kspace + +# BraTS dataset, AF=8 +python train_brats.py --root_path $root_path_8x \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8\ + --exp FSMNet_BraTS_8x --use_time_model --use_kspace + + + + +# Test +#BraTS dataset, AF=4 + +python test_brats.py --root_path $root_path_4x \ + --gpu 0 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_BraTS_4x --phase test --use_time_model --use_kspace + + + +#BraTS dataset, AF=8 + +python test_brats.py --root_path $root_path_8x \ + --gpu 1 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_BraTS_8x --phase test --use_time_model --use_kspace + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/fastmri.sh b/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/fastmri.sh new file mode 100644 index 
0000000000000000000000000000000000000000..ac5f270a64aebbd4b9b424901e07521aa03bc6dc --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/fastmri.sh @@ -0,0 +1,59 @@ + +cd /home/v-qichen3/MRI_recon/code/Frequency-Diffusion/FSMNet + +data_root=/home/v-qichen3/MRI_recon/data/fastmri + + +python train_fastmri.py --root_path $data_root \ + --gpu 2 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --MRIDOWN 4X --MASKTYPE random \ + --num_timesteps 5 --image_size 320 --use_kspace --use_time_model + +# fastMRI dataset, AF=8 + +# python train_fastmri.py --root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ +# --exp FSMNet_fastmri_8x --MRIDOWN 8X --MASKTYPE equispaced \ +# --num_timesteps 5 --image_size 320 --use_kspace --use_time_model + + +# python train_fastmri.py --root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ +# --exp FSMNet_fastmri_12x --MRIDOWN 12X --MASKTYPE equispaced \ +# --num_timesteps 5 --image_size 320 --use_kspace --use_time_model + + + +# # Test +# #fastMRI dataset, AF=4 +# model_4x=model/FSMNet_fastmri_4x/iter_100000.pth + +python test_fastmri.py --root_path $data_root \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --phase test --MRIDOWN 4X \ + --num_timesteps 5 --image_size 320 --use_kspace --use_time_model --test_sample Ksample --snapshot_path /home/v-qichen3/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion + + + + # ColdDiffusion DDPM + +# #fastMRI dataset, AF=8 + +python test_fastmri.py --root_path $data_root \ + --gpu 3 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --phase test --MRIDOWN 8X \ + --num_timesteps 5 --image_size 320 --use_kspace --use_time_model --test_sample Ksample --snapshot_path /home/v-qichen3/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion + + + # ColdDiffusion DDPM + + +# python test_fastmri.py --root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ +# --exp FSMNet_fastmri_12x --phase test --MRIDOWN 12X \ +# --num_timesteps 5 --image_size 320 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + + +# # FSMNet_fastmri_12x_t30_kspace_time + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/fastmri_8x.sh b/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/fastmri_8x.sh new file mode 100644 index 0000000000000000000000000000000000000000..9acc97c09e8da1fed7286b9a56a3f62478d80798 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/fastmri_8x.sh @@ -0,0 +1,52 @@ + +cd /home/v-qichen3/MRI_recon/code/Frequency-Diffusion/FSMNet + +data_root=/home/v-qichen3/MRI_recon/data/fastmri + + +# python train_fastmri.py --root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ +# --exp FSMNet_fastmri_4x --MRIDOWN 4X --MASKTYPE random \ +# --num_timesteps 5 --image_size 320 --use_kspace --use_time_model + +# fastMRI dataset, AF=8 + +python train_fastmri.py --root_path $data_root \ + --gpu 3 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --MRIDOWN 8X --MASKTYPE equispaced \ + --num_timesteps 5 --image_size 320 --use_kspace --use_time_model + + +# python train_fastmri.py 
--root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ +# --exp FSMNet_fastmri_12x --MRIDOWN 12X --MASKTYPE equispaced \ +# --num_timesteps 5 --image_size 320 --use_kspace --use_time_model + + + +# # Test +# #fastMRI dataset, AF=4 +# model_4x=model/FSMNet_fastmri_4x/iter_100000.pth + +# python test_fastmri.py --root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ +# --exp FSMNet_fastmri_4x --phase test --MRIDOWN 4X \ +# --num_timesteps 5 --image_size 320 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + +# #fastMRI dataset, AF=8 + +# python test_fastmri.py --root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ +# --exp FSMNet_fastmri_8x --phase test --MRIDOWN 8X \ +# --num_timesteps 5 --image_size 320 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + +# python test_fastmri.py --root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ +# --exp FSMNet_fastmri_12x --phase test --MRIDOWN 12X \ +# --num_timesteps 5 --image_size 320 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + + +# # FSMNet_fastmri_12x_t30_kspace_time + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/m4raw.sh b/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/m4raw.sh new file mode 100644 index 0000000000000000000000000000000000000000..0f1b1197703e3ff89efaf9c162280f5c9fe07b89 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/bash/m4raw.sh @@ -0,0 +1,149 @@ + +cd /home/v-qichen3/MRI_recon/code/Frequency-Diffusion/FSMNet/ + +data_root=/home/v-qichen3/MRI_recon/data/m4raw + + +python train_m4raw.py --root_path $data_root \ + --gpu 3 --batch_size 4 --base_lr 0.0005 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_m4raw_4x_lr5e-4 --MRIDOWN 4X --MASKTYPE random \ + --num_timesteps 30 --image_size 240 --use_kspace --use_time_model + + +# m4raw dataset, AF=8 +# python train_m4raw.py --root_path $data_root \ +# --gpu 0 --batch_size 8 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ +# --exp FSMNet_m4raw_8x --MRIDOWN 8X --MASKTYPE equispaced \ +# --num_timesteps 5 --image_size 240 --use_kspace --use_time_model + + +# data_root=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee +# python train_m4raw.py --root_path $data_root \ +# --gpu 1 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ +# --exp FSMNet_m4raw_12x --MRIDOWN 12X --MASKTYPE equispaced \ +# --num_timesteps 30 --image_size 240 --use_kspace --use_time_model + + +# ---------------- Test ---------------- + +python test_m4raw.py --root_path $data_root \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_m4raw_4x --phase test --MASKTYPE random --MRIDOWN 4X \ + --num_timesteps 30 --image_size 240 --use_kspace --use_time_model --test_tag no_distortion \ + --test_sample ColdDiffusion --snapshot_path /home/v-qichen3/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time + +#m4raw dataset, AF=8 +# python test_m4raw.py --root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ +# --exp FSMNet_m4raw_8x --phase test --MASKTYPE equispaced --MRIDOWN 8X \ +# --num_timesteps 5 --image_size 240 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + + +# python 
test_m4raw.py --root_path $data_root \ +# --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ +# --exp FSMNet_m4raw_12x --phase test --MASKTYPE equispaced --MRIDOWN 12X \ +# --num_timesteps 5 --image_size 240 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + + +# ------------------------------------ +# NMSE: 3.2228 ± 0.2863 +# PSNR: 28.8534 ± 0.4793 +# SSIM: 0.7769 ± 0.0126 +# ------------------------------------ +# All NMSE: 3.2175 ± 0.5230 +# All PSNR: 27.8081 ± 0.7997 +# All SSIM: 0.7507 ± 0.0218 +# ------------------------------------ +# Save Path: model/FSMNet_m4raw_4x_t5_new_kspace_no_distortion//result_case/ + +# ------------------------------------ +# NMSE: 2.9359 ± 0.2239 +# PSNR: 29.2540 ± 0.4309 +# SSIM: 0.7922 ± 0.0116 +# ------------------------------------ +# DDPM NMSE: 9.8704 ± 0.5135 +# DDPM PSNR: 23.9812 ± 0.2724 +# DDPM SSIM: 0.6568 ± 0.0124 +# ------------------------------------ +# Save Path: model/FSMNet_m4raw_4x_t5_new_kspace_no_distortion_time/result_case/ + +# ------------------------------------ +# NMSE: 2.2551 ± 0.1361 +# PSNR: 30.3949 ± 0.4174 +# SSIM: 0.8167 ± 0.0126 +# ------------------------------------ +# Save Path: FSMNet_m4raw_4x_t5_new_kspace_time + +# ------------------------------------ +# NMSE: 3.8027 ± 0.3095 +# PSNR: 27.7977 ± 0.4213 +# SSIM: 0.7529 ± 0.0119 +# ------------------------------------ +# Save Path: FSMNet_m4raw_4x + +# ------------------------------------ +# NMSE: 6.6898 ± 0.6444 +# PSNR: 25.6848 ± 0.4866 +# SSIM: 0.6844 ± 0.0149 +# ------------------------------------ +# ColdDiffusion NMSE: 8.4526 ± 0.7364 +# ColdDiffusion PSNR: 24.6651 ± 0.4320 +# ColdDiffusion SSIM: 0.6489 ± 0.0141 +# ------------------------------------ +# Save Path: FSMNet_m4raw_8x_t5_new_kspace_time + +# ------------------------------------ +# NMSE: 7.8239 ± 0.7702 +# PSNR: 24.1138 ± 0.5690 +# SSIM: 0.6421 ± 0.0164 +# ------------------------------------ +# Save Path: FSMNet_m4raw_8x + + +# ------------------------------------ +# NMSE: 7.9649 ± 0.7895 +# PSNR: 24.9283 ± 0.4875 +# SSIM: 0.6548 ± 0.0150 +# ------------------------------------ +# All NMSE: 7.9485 ± 1.4481 +# All PSNR: 23.8971 ± 0.8682 +# All SSIM: 0.6231 ± 0.0348 +# ------------------------------------ +# Save Path: model/FSMNet_m4raw_8x_t5_new_kspace_no_distortion//result_case/ + + +# ------------------------------------ +# NMSE: 7.4375 ± 0.7510 +# PSNR: 25.2266 ± 0.4907 +# SSIM: 0.6662 ± 0.0140 +# ------------------------------------ +# All NMSE: 7.4210 ± 1.3740 +# All PSNR: 24.1975 ± 0.8648 +# All SSIM: 0.6353 ± 0.0317 +# ------------------------------------ +# Save Path: model/FSMNet_m4raw_8x_t5_new_kspace_time_no_distortion//result_case/ + + + +# ------------------------------------ +# NMSE: 9.9329 ± 0.8853 +# PSNR: 23.9651 ± 0.4525 +# SSIM: 0.6380 ± 0.0161 +# ------------------------------------ +# Save Path: model/FSMNet_m4raw_12x_t5_new_kspace_time//result_case/ + + +# ------------------------------------ +# NMSE: 9.4328 ± 0.9025 +# PSNR: 23.1375 ± 0.5578 +# SSIM: 0.6096 ± 0.0188 +# ------------------------------------ +# All NMSE: 9.4449 ± 1.9331 +# All PSNR: 22.2995 ± 1.0077 +# All SSIM: 0.5807 ± 0.0359 +# ------------------------------------ +# Save Path: model/FSMNet_m4raw_12x//result_case/ + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_DuDo_dataloader.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_DuDo_dataloader.py new file mode 100644 index 
0000000000000000000000000000000000000000..e2b06691ee683a347d4a20948d03598db65e9c08 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_DuDo_dataloader.py @@ -0,0 +1,295 @@ +""" +dual-domain network的dataloader, 读取两个模态的under-sampled和fully-sampled kspace data, 以及high-quality image作为监督信号。 +""" + + +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset + +from .kspace_subsample import undersample_mri, mri_fft + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + Returns: + The IFFT of the input. 
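+        Note:
+            ``data`` is expected to be a real-valued tensor whose last dimension has
+            size 2 (the real and imaginary parts), e.g. the output of
+            ``torch.view_as_real``; ``norm`` is forwarded to ``torch.fft.ifftn``.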
+ """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + + return data + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, HF_refine = 'False', split='train', MRIDOWN='4X', SNR=15, \ + transform=None, input_round = None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self.HF_refine = HF_refine + self.input_round = input_round + self._MRIDOWN = MRIDOWN + self._SNR = SNR + self.im_ids = [] + self.t2_images = [] + self.splits_path = "/data/xiaohan/BRATS_dataset/cv_splits_100patients/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + # self.test_file = self.splits_path + 'train_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + self.t2_images.append(t2_path) + + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + image_name = self.t1_images[index].split('t1')[0] + # print("image name:", image_name) + + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + # print("images:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + # print("t1 before standardization:", t1.max(), t1.min(), t1.mean()) + # print("loaded t1 range:", t1.max(), t1.min()) + # print("loaded t2 range:", t2.max(), t2 .min()) + + ### normalize the MRI image by divide_max + t1_max, t2_max = t1.max(), t2.max() + t1 = t1/t1_max + t2 = t2/t2_max + sample_stats = {"t1_max": t1_max, "t2_max": t2_max, "image_name": image_name} + + # sample_stats = {"t1_max": 1.0, "t2_max": 1.0} + + ### convert images to kspace and perform undersampling. + t1_kspace_in, t1_in, t1_kspace, t1_img = mri_fft(t1, _SNR = self._SNR) + t2_kspace_in, t2_in, t2_kspace, t2_img, mask = undersample_mri( + t2, _MRIDOWN = self._MRIDOWN, _SNR = self._SNR) + + + # print("loaded t2 range:", t2.max(), t2.min()) + # print("t2_under_img range:", t2_under_img.max(), t2_under_img.min()) + # print("t2_kspace real_part range:", t2_kspace.real.max(), t2_kspace.real.min()) + # print("t2_kspace imaginary_part range:", t2_kspace.imag.max(), t2_kspace.imag.min()) + # print("t2_kspace_in real_part range:", t2_kspace_in.real.max(), t2_kspace_in.real.min()) + # print("t2_kspace_in imaginary_part range:", t2_kspace_in.imag.max(), t2_kspace_in.imag.min()) + + if self.HF_refine == "False": + sample = {'t1': t1_img, 't1_in': t1_in, 't1_kspace': t1_kspace, 't1_kspace_in': t1_kspace_in, \ + 't2': t2_img, 't2_in': t2_in, 't2_kspace': t2_kspace, 't2_kspace_in': t2_kspace_in, \ + 't2_mask': mask} + + elif self.HF_refine == "True": + ### 读取上一步重建的kspace data. 
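+            # (in English) Load the k-space data reconstructed in the previous round
+            # (the *_recon_kspace_<round>_DudoLoss.npy files) for both T1 and T2 and
+            # return it alongside the raw samples as extra inputs for refinement.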
+ t1_krecon_path = self._base_dir + self.t1_images[index].replace( + 't1.png', 't1_' + str(self._SNR) + 'dB_recon_kspace_' + self.input_round + '_DudoLoss.npy') + t2_krecon_path = self._base_dir + self.t1_images[index].replace('t1.png', 't2_' + self._MRIDOWN + \ + '_' + str(self._SNR) + 'dB_recon_kspace_' + self.input_round + '_DudoLoss.npy') + + t1_krecon = np.load(t1_krecon_path) + t2_krecon = np.load(t2_krecon_path) + # print("t1 and t2 recon kspace:", t1_krecon.shape, t2_krecon.shape) + # + sample = {'t1': t1_img, 't1_in': t1_in, 't1_kspace': t1_kspace, 't1_kspace_in': t1_kspace_in, \ + 't2': t2_img, 't2_in': t2_in, 't2_kspace': t2_kspace, 't2_kspace_in': t2_kspace_in, \ + 't2_mask': mask, 't1_krecon': t1_krecon, 't2_krecon': t2_krecon} + + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img = torch.from_numpy(img).float() + target = torch.from_numpy(target).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'ct': img, 'mri': target} + + +# class ToTensor(object): +# """Convert ndarrays in sample to Tensors.""" + +# def __call__(self, sample): +# # swap color axis because +# # numpy image: H x W x C +# # torch image: C X H X W +# img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) +# img = sample['image'][:, :, None].transpose((2, 0, 1)) +# target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) +# target = sample['target'][:, :, None].transpose((2, 0, 1)) +# # print("img_in before_numpy range:", img_in.max(), img_in.min()) +# img_in = torch.from_numpy(img_in).float() +# img = torch.from_numpy(img).float() +# target_in = torch.from_numpy(target_in).float() +# target = torch.from_numpy(target).float() +# # print("img_in range:", img_in.max(), img_in.min()) + +# return {'ct_in': img_in, +# 'ct': img, +# 'mri_in': target_in, +# 'mri': target} diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_dataloader.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..28d52593311a0bbe1c679fc0687cbe949e85dc7c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_dataloader.py @@ -0,0 +1,174 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import os +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', SNR=15, transform=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.splits_path = "/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/cv_splits/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + + elif split=='test': + self.test_file = 
self.splits_path + 'test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + else: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + if MRIDOWN == "False": + t2_under_path = image_path.replace('t1', 't2_' + str(SNR) + 'dB') + else: + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + + # print("image paths:", image_path, t1_under_path, t2_path, t2_under_path) + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + # print("t1 images:", self.t1_images) + # print("t2 images:", self.t2_images) + # print("t1_undermri_images:", self.t1_undermri_images) + # print("t2_undermri_images:", self.t2_undermri_images) + + self.transform = transform + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + ### 两种settings. + ### 1. T1 fully-sampled 不加noise, T2 down-sampled, 做MRI acceleration. + ### 2. T1 fully-sampled 但是加noise, T2 down-sampled同时也加noise, 同时做MRI acceleration and enhancement. + ### T1, T2两个模态的输入都是low-quality images. 
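+        # (English translation of the notes above)
+        # Two settings are supported:
+        #   1. T1 fully sampled without noise, T2 down-sampled -> MRI acceleration only.
+        #   2. T1 fully sampled but noisy, T2 down-sampled and noisy -> joint MRI
+        #      acceleration and enhancement.
+        # In both settings the network inputs for T1 and T2 are low-quality images.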
+ sample = {'image_in': np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0, + 'image': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + 'target_in': np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0, + 'target': np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0} + + + # ### 2023/05/23, Xiaohan, 把T1模态的输入改成high-quality图像(和ground truth一致,看能否为T2提供更好的guidance)。 + # sample = {'image_in': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + # 'image': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + # 'target_in': np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0, + # 'target': np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0} + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git 
a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_dataloader_new.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_dataloader_new.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b4d0b6bc20d8bc3164b73b3d26ca09d8f1f98b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_dataloader_new.py @@ -0,0 +1,384 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset +from torchvision import transforms +from .albu_transform import get_albu_transforms + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', \ + SNR=15, transform=None, input_normalize=None, use_kspace=False): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.t1_krecon_images, self.t2_krecon_images = [], [] + self.kspace_refine = "False" # ADD + self.use_kspace = use_kspace + + self.albu_transforms = get_albu_transforms(split, (240, 240)) + + name = base_dir.rstrip("/ ").split('/')[-1] + print("base_dir=", base_dir, ", folder name =", name) + self.splits_path = base_dir.replace(name, 'cv_splits_100patients/') + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + + if SNR == 0: + t1_under_path = image_path + + if self.kspace_refine == "False": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + elif self.kspace_refine == "True": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_krecon') + + if self.kspace_refine == "False": + t1_krecon_path = image_path + t2_krecon_path = image_path + + # if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + + else: + t1_under_path = image_path.replace('t1', 
't1_' + str(SNR) + 'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + t1_krecon_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_krecon_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + self.t1_krecon_images.append(t1_krecon_path) + self.t2_krecon_images.append(t2_krecon_path) + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t1_krecon = np.array(Image.open(self._base_dir + self.t1_krecon_images[index]))/255.0 + + t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + t2_krecon = np.array(Image.open(self._base_dir + self.t2_krecon_images[index]))/255.0 + + if self.input_normalize == "mean_std": + t1_in, t1_mean, t1_std = normalize_instance(t1_in, eps=1e-11) + t1 = normalize(t1, t1_mean, t1_std, eps=1e-11) + t2_in, t2_mean, t2_std = normalize_instance(t2_in, eps=1e-11) + t2 = normalize(t2, t2_mean, t2_std, eps=1e-11) + + t1_krecon = normalize(t1_krecon, t1_mean, t1_std, eps=1e-11) + t2_krecon = normalize(t2_krecon, t2_mean, t2_std, eps=1e-11) + + ### clamp input to ensure training stability. 
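+            # After the per-image mean/std normalization above, the intensities are
+            # roughly zero-mean with unit variance, so clipping to [-6, 6] only removes
+            # extreme outliers (beyond about six standard deviations) that could
+            # otherwise destabilize training.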
+ t1_in = np.clip(t1_in, -6, 6) + t1 = np.clip(t1, -6, 6) + t2_in = np.clip(t2_in, -6, 6) + t2 = np.clip(t2, -6, 6) + + t1_krecon = np.clip(t1_krecon, -6, 6) + t2_krecon = np.clip(t2_krecon, -6, 6) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + t1_in = (t1_in - t1_in.min())/(t1_in.max() - t1_in.min()) + t1 = (t1 - t1.min())/(t1.max() - t1.min()) + t2_in = (t2_in - t2_in.min())/(t2_in.max() - t2_in.min()) + t2 = (t2 - t2.min())/(t2.max() - t2.min()) + sample_stats = 0 + + elif self.input_normalize == "divide": + sample_stats = 0 + + if True: #self.use_kspace: + sample = self.albu_transforms(image=t1_in, image2=t1, + image3=t2_in, image4=t2, + image5=t1_krecon, image6=t2_krecon) + + sample = {'image_in': sample['image'].astype(float), + 'image': sample['image2'].astype(float), + 'image_krecon': sample['image5'].astype(float), + 'target_in': sample['image3'].astype(float), + 'target': sample['image4'].astype(float), + 'target_krecon': sample['image6'].astype(float)} + + else: + sample = {'image_in': t1_in.astype(float), + 'image': t1.astype(float), + 'image_krecon': t1_krecon.astype(float), + 'target_in': t2_in.astype(float), + 'target': t2.astype(float), + 'target_krecon': t2_krecon.astype(float)} + + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + + +def add_gaussian_noise(img, mean=0, std=1): + noise = std * torch.randn_like(img) + mean + noisy_img = img + noise + return torch.clamp(noisy_img, 0, 1) + + + +class AddNoise(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + add_gauss_noise = transforms.GaussianBlur(kernel_size=5) + add_poiss_noise = transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)) + + add_noise = transforms.RandomApply([add_gauss_noise, add_poiss_noise], p=0.5) + + img_in = add_noise(img_in) + target_in = add_noise(target_in) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + + return sample + + + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_krecon = sample['image_krecon'] + target_krecon = sample['target_krecon'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + img_krecon = np.pad(img_krecon, pad_size, mode='reflect') + target_krecon = np.pad(target_krecon, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + img_krecon = img_krecon[ww:ww+crop_size, hh:hh+crop_size] + target_krecon = target_krecon[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'image_krecon': img_krecon, \ + 'target_in': target_in, 'target': target, 'target_krecon': target_krecon} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample 
to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + +class RandomFlip(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + # horizontal flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 1) + img = cv2.flip(img, 1) + target_in = cv2.flip(target_in, 1) + target = cv2.flip(target, 1) + + # vertical flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 0) + img = cv2.flip(img, 0) + target_in = cv2.flip(target_in, 0) + target = cv2.flip(target, 0) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + + +class RandomRotate(object): + def __call__(self, sample, center=None, scale=1.0): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + degrees = [0, 90, 180, 270] + angle = random.choice(degrees) + + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + + img_in = cv2.warpAffine(img_in, matrix, (w, h)) + img = cv2.warpAffine(img, matrix, (w, h)) + target_in = cv2.warpAffine(target_in, matrix, (w, h)) + target = cv2.warpAffine(target, matrix, (w, h)) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + + image_krecon = sample['image_krecon'][:, :, None].transpose((2, 0, 1)) + target_krecon = sample['target_krecon'][:, :, None].transpose((2, 0, 1)) + + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + image_krecon = torch.from_numpy(image_krecon).float() + target_krecon = torch.from_numpy(target_krecon).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'image_in': img_in, + 'image': img, + 'target_in': target_in, + 'target': target, + 'image_krecon': image_krecon, + 'target_krecon': target_krecon} diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_kspace_dataloader.py 
b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_kspace_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..871a153b20eac89e45ec0025e2aa31476360fde0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/BRATS_kspace_dataloader.py @@ -0,0 +1,298 @@ +""" +Load the low-quality and high-quality images from the BRATS dataset and transform to kspace. +""" + + +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset + +from .kspace_subsample import undersample_mri, mri_fft + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + Returns: + The IFFT of the input. 
+ """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + + return data + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', SNR=15, transform=None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.splits_path = "/data/xiaohan/BRATS_dataset/cv_splits_100patients/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + # self.test_file = self.splits_path + 'train_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + else: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + # print("t1 images:", self.t1_images) + # print("t2 images:", self.t2_images) + # print("t1_undermri_images:", self.t1_undermri_images) + # print("t2_undermri_images:", self.t2_undermri_images) + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + # t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + # t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + # print("images:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + # print("t1 before standardization:", t1.max(), t1.min(), t1.mean()) + # print("t1 range:", t1.max(), t1.min()) + # print("t2 range:", t2.max(), t2 .min()) + + if self.input_normalize == "mean_std": + ### 对input image和target image都做(x-mean)/std的归一化操作 + t1, t1_mean, t1_std = normalize_instance(t1, eps=1e-11) + t2, t2_mean, t2_std = normalize_instance(t2, eps=1e-11) + + ### clamp input to ensure training stability. 
+ t1 = np.clip(t1, -6, 6) + t2 = np.clip(t2, -6, 6) + # print("t1 after standardization:", t1.max(), t1.min(), t1.mean()) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + # t1 = (t1 - t1.min())/(t1.max() - t1.min()) + # t2 = (t2 - t2.min())/(t2.max() - t2.min()) + t1 = t1/t1.max() + t2 = t2/t2.max() + sample_stats = 0 + + elif self.input_normalize == "divide": + sample_stats = 0 + + + ### convert images to kspace and perform undersampling. + # t1_kspace, t1_masked_kspace, t1_img, t1_under_img = undersample_mri(t1, _MRIDOWN = None) + t1_kspace, t1_img = mri_fft(t1) + t2_kspace, t2_masked_kspace, t2_img, t2_under_img, mask = undersample_mri(t2, _MRIDOWN = self._MRIDOWN) + + + sample = {'t1': t1_img, 't2': t2_img, 'under_t2': t2_under_img, "t2_mask": mask, \ + 't1_kspace': t1_kspace, 't2_kspace': t2_kspace, 't2_masked_kspace': t2_masked_kspace} + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img = torch.from_numpy(img).float() + target = torch.from_numpy(target).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'ct': img, 'mri': target} + + +# class ToTensor(object): +# """Convert ndarrays in sample to Tensors.""" + +# def __call__(self, sample): +# # swap color axis because +# # numpy image: H x W x C +# # torch image: C X H X W +# img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) +# img = sample['image'][:, :, None].transpose((2, 0, 1)) +# target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) +# target = sample['target'][:, :, None].transpose((2, 0, 1)) +# # print("img_in before_numpy range:", img_in.max(), img_in.min()) +# img_in = torch.from_numpy(img_in).float() +# img = torch.from_numpy(img).float() +# target_in = torch.from_numpy(target_in).float() +# target = torch.from_numpy(target).float() +# # print("img_in range:", img_in.max(), img_in.min()) + +# return {'ct_in': img_in, +# 'ct': img, +# 'mri_in': target_in, +# 'mri': target} diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__init__.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/__init__.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93b149b0ee10586976c97405900704f9bbc8e761 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/albu_transform.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/albu_transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd31bcc94b6923dbaa89d202ea6d80854cc9cf0e Binary files /dev/null and 
b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/albu_transform.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/fastmri.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/fastmri.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eceb620294195748631c3ccf0536e9cd4bb465e Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/fastmri.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/kspace_subsample.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/kspace_subsample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52d065aba8de57fe34e5eec6676cc5b090293d98 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/kspace_subsample.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/m4_utils.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/m4_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abe3cd80e93efd3567bdfaebbc448c735e14b4ff Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/m4_utils.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/m4raw_dataloader.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/m4raw_dataloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad1d65f34cf526b1fa6a5af11ba242b76de51a3e Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/m4raw_dataloader.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/m4raw_std_dataloader.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/m4raw_std_dataloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bd64256460da7b393f6bf4ae554d2a1f3168b43 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/m4raw_std_dataloader.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/math.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/math.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29666fefb34c64dfa6246bca490866a55c134c3d Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/math.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/subsample.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/subsample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9a80a34f5ae51d9b965bacaefa4b49dec25e432 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/subsample.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/transforms.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5f6f32866f71e658e7ff6d1d2ae100622c157dd Binary files /dev/null and 
b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/__pycache__/transforms.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/albu_transform.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/albu_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..abe5bd11f634bcc7d29a6dfe5d68319c1c2b581c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/albu_transform.py @@ -0,0 +1,75 @@ +# -*- encoding: utf-8 -*- +#Time :2022/02/24 18:14:15 +#Author :Hao Chen +#FileName :trans_lib.py +#Version :2.0 + +import cv2 +import torch +import numpy as np +import albumentations as A + + +def get_albu_transforms(type="train", img_size = (192, 192)): + if type == 'train': + compose = [ + # A.VerticalFlip(p=0.5), + # A.HorizontalFlip(p=0.5), + + A.ShiftScaleRotate(shift_limit=0.2, scale_limit=(-0.2, 0.2), + rotate_limit=5, p=0.5), + + A.OneOf([ + A.GridDistortion(num_steps=1, distort_limit=0.3, p=1.0), + A.ElasticTransform(alpha=2, sigma=5, p=1.0) + ], p=0.5), + + A.Resize(img_size[0], img_size[1])] + else: + compose = [A.Resize(img_size[0], img_size[1])] + + return A.Compose(compose, p=1.0, additional_targets={'image2': 'image', + 'image3': 'image', + 'image4': 'image', + 'image5': 'image', + 'image6': 'image', + "mask2": "mask"}) + + + + +# Beta function +def gamma_concern(img, gamma): + mean = torch.mean(img) + + img = (img - mean) * gamma + img = img + mean + img = torch.clip(img, 0, 1) + + return img + +def gamma_power(img, gamma, direction=0): + if direction == 1: + img = 1 - img + img = torch.pow(img, gamma) + + img = img / torch.max(img) + if direction == 1: + img = 1 - img + + return img + +def gamma_exp(img, gamma, direction=0): + if direction == 1: + img = 1 - img + + img = torch.exp(img * gamma) + img = img / torch.max(img) + + if direction == 1: + img = 1 - img + return img + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/brats_4X_mask.npy b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/brats_4X_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..bdf32304f95640286541ceb1068582dc69b0d60a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/brats_4X_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76341ba680a0bc9c80389e01f8511e5bd99ab361eeb48d83516904b84cccc518 +size 460928 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/brats_8X_mask.npy b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/brats_8X_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..c389e708adeb3307db90ff071599256b8f59dab5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/brats_8X_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c5160add079e8f4dc2496e5ef87c110015026d9f6116329da2238a73d8bc104 +size 230528 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/brats_data_gen.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/brats_data_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..b14c5361c534a67edc6a9fef311fce4f7f45fda4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/brats_data_gen.py @@ -0,0 +1,302 @@ +""" +Xiaohan Xing, 2023/04/08 +对BRATS 2020数据集进行Pre-processing, 得到各个模态的under-sampled input image和2d groung-truth. 
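# --- Editor's note: illustrative sketch, not part of the patch. ---------------
# get_albu_transforms above registers image2..image6 as additional targets so
# that every paired image (inputs, targets, k-space reconstructions) receives
# exactly the same geometric augmentation. Minimal demonstration of that
# albumentations mechanism on toy data:
import numpy as np
import albumentations as A

aug = A.Compose([A.HorizontalFlip(p=1.0)],
                additional_targets={"image2": "image"})
a = np.arange(16, dtype=np.float32).reshape(4, 4)
out = aug(image=a, image2=a.copy())
assert np.array_equal(out["image"], out["image2"])   # identical flip applied to both
# -------------------------------------------------------------------------------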
+""" +import os +import argparse +import numpy as np +import nibabel as nib +from scipy import ndimage as nd +from scipy import ndimage +from skimage import filters +from skimage import io +import torch +import torch.fft +from matplotlib import pyplot as plt + +MRIDOWN=2 +SNR = 35 + + +class MaskFunc_Cartesian: + """ + MaskFunc creates a sub-sampling mask of a given shape. + The mask selects a subset of columns from the input k-space data. If the k-space data has N + columns, the mask picks out: + a) N_low_freqs = (N * center_fraction) columns in the center corresponding to + low-frequencies + b) The other columns are selected uniformly at random with a probability equal to: + prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). + This ensures that the expected number of columns selected is equal to (N / acceleration) + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is + called. + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there + is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% + probability that 8-fold acceleration with 4% center fraction is selected. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly + each time. + accelerations (List[int]): Amount of under-sampling. This should have the same length + as center_fractions. If multiple values are provided, then one of these is chosen + uniformly each time. An acceleration of 4 retains 25% of the columns, but they may + not be spaced evenly. + """ + if len(center_fractions) != len(accelerations): + raise ValueError('Number of center fractions should match number of accelerations') + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random.RandomState() + + def __call__(self, shape, seed=None): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The shape should have + at least 3 dimensions. Samples are drawn along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting the seed + ensures the same mask is generated each time for the same shape. + Returns: + torch.Tensor: A mask of the specified shape. 
+ """ + if len(shape) < 3: + raise ValueError('Shape should have 3 or more dimensions') + + self.rng.seed(seed) + num_cols = shape[-2] + + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + # Create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs + 1e-10) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad:pad + num_low_freqs] = True + + # Reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + mask = mask.repeat(shape[0], 1, 1) + + return mask + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2)) + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + spectrum = spectrum * mask[None, :, :, None] + return spectrum + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2)) + + return image + + +def simulate_undersample_mri(raw_mri): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + ff = MaskFunc_Cartesian([0.2], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4 + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + kspace = mri_fourier_transform_2d(mri, mask) + kspace = add_gaussian_noise(kspace) + mri_recon = mri_inver_fourier_transform_2d(kspace) + kdata = torch.sqrt(kspace.real ** 2 + kspace.imag ** 2 + 1e-10) + kdata = kdata.data.numpy()[0, :, :, 0] + + under_img = torch.sqrt(mri_recon.real ** 2 + mri_recon.imag ** 2) + under_img = under_img.data.numpy()[0, :, :, 0] + + return under_img, kspace + + +def add_gaussian_noise(img, snr=15): + ### 根据SNR确定noise的放大比例 + num_pixels = img.shape[0]*img.shape[1]*img.shape[2]*img.shape[3] + psr = torch.sum(torch.abs(img.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + + noise_r = torch.randn_like(img.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(img.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(img.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noise_img = img + noise + # print("original image:", img) + # print("gaussian noise:", noise) + + return noise_img + + +def complexsing_addnoise(img, snr): + ### add noise to the real part of the image. + img_numpy = img.cpu().numpy() + # print("kspace data:", img) + s_r = np.real(img_numpy) + num_pixels = s_r.shape[0]*s_r.shape[1]*s_r.shape[2]*s_r.shape[3] + psr = np.sum(np.abs(s_r)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + # print("PSR:", psr, "PNR:", pnr) + noise_r = np.random.randn(num_pixels)*np.sqrt(pnr) + + ### add noise to the iamginary part of the image. 
+ s_im = np.imag(img_numpy) + psim = np.sum(np.abs(s_im)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = np.random.randn(num_pixels)*np.sqrt(pnim) + + noise = torch.Tensor(noise_r) + 1j*torch.Tensor(noise_im) + sn = img + noise + # print("noisy data:", sn) + # sn = torch.Tensor(sn) + + return sn + + + +def _parse(rootdir): + filetree = {} + + for sample_file in os.listdir(rootdir): + sample_dir = rootdir + sample_file + subject = sample_file + + for filename in os.listdir(sample_dir): + modality = filename.split('.').pop(0).split('_')[-1] + + if subject not in filetree: + filetree[subject] = {} + filetree[subject][modality] = filename + + return filetree + + + +def clean(rootdir, savedir, source_modality, target_modality): + filetree = _parse(rootdir) + print("filetree:", filetree) + + if not os.path.exists(savedir+'/img_norm'): + os.makedirs(savedir+'/img_norm') + + for subject, modalities in filetree.items(): + print(f'{subject}:') + + if source_modality not in modalities or target_modality not in modalities: + print('-> incomplete') + continue + + source_path = os.path.join(rootdir, subject, modalities[source_modality]) + target_path = os.path.join(rootdir, subject, modalities[target_modality]) + + source_image = nib.load(source_path) + target_image = nib.load(target_path) + + source_volume = source_image.get_fdata() + target_volume = target_image.get_fdata() + source_binary_volume = np.zeros_like(source_volume) + target_binary_volume = np.zeros_like(target_volume) + + print("source volume:", source_volume.shape) + print("target volume:", target_volume.shape) + + for i in range(source_binary_volume.shape[-1]): + source_slice = source_volume[:, :, i] + target_slice = target_volume[:, :, i] + + if source_slice.min() == source_slice.max(): + print("invalide source slice") + source_binary_volume[:, :, i] = np.zeros_like(source_slice) + else: + source_binary_volume[:, :, i] = ndimage.morphology.binary_fill_holes( + source_slice > filters.threshold_li(source_slice)) + + if target_slice.min() == target_slice.max(): + print("invalide target slice") + target_binary_volume[:, :, i] = np.zeros_like(target_slice) + else: + target_binary_volume[:, :, i] = ndimage.morphology.binary_fill_holes( + target_slice > filters.threshold_li(target_slice)) + + source_volume = np.where(source_binary_volume, source_volume, np.ones_like( + source_volume) * source_volume.min()) + target_volume = np.where(target_binary_volume, target_volume, np.ones_like( + target_volume) * target_volume.min()) + + ## resize + if source_image.header.get_zooms()[0] < 0.6: + scale = np.asarray([240, 240, source_volume.shape[-1]]) / np.asarray(source_volume.shape) + source_volume = nd.zoom(source_volume, zoom=scale, order=3, prefilter=False) + target_volume = nd.zoom(target_volume, zoom=scale, order=0, prefilter=False) + + # save volume into images + source_volume = (source_volume-source_volume.min())/(source_volume.max()-source_volume.min()) + target_volume = (target_volume-target_volume.min())/(target_volume.max()-target_volume.min()) + + for i in range(source_binary_volume.shape[-1]): + source_binary_slice = source_binary_volume[:, :, i] + target_binary_slice = target_binary_volume[:, :, i] + if source_binary_slice.max() > 0 and target_binary_slice.max() > 0: + dd = target_volume.shape[0] // 2 + target_slice = target_volume[dd - 120:dd + 120, dd - 120:dd + 120, i] + source_slice = source_volume[dd - 120:dd + 120, dd - 120:dd + 120, i] + print("source slice range:", source_slice.shape) + print("target slice range:", 
target_slice.max(), target_slice.min()) + # undersample MRI + source_under_img, source_kspace = simulate_undersample_mri(source_slice) + target_under_img, target_kspace = simulate_undersample_mri(target_slice) + + # # io.imsave(savedir+'/img_norm/'+subject+'_'+str(i)+'_'+source_modality+'.png', (source_slice * 255.0).astype(np.uint8)) + io.imsave(savedir + '/img_temp/' + subject + '_' + str(i) + '_' + source_modality + '_' + str(MRIDOWN) + 'X_' + str(SNR) + 'dB_undermri.png', + (source_under_img * 255.0).astype(np.uint8)) + + # io.imsave(savedir + '/img_temp/' + subject + '_' + str(i) + '_' + source_modality + '_' + str(MRIDOWN) + 'X_undermri.png', + # (source_under_img * 255.0).astype(np.uint8)) + # # io.imsave(savedir+'/img_norm/'+subject+'_'+str(i)+'_'+target_modality+'.png', (target_slice * 255.0).astype(np.uint8)) + # io.imsave(savedir + '/img_norm/' + subject + '_' + str(i) + '_' + target_modality + '_' + str(MRIDOWN) + 'X_undermri.png', + # (target_under_img * 255.0).astype(np.uint8)) + + # np.savez_compressed(rootdir + '/img_norm/' + subject + '_' + str(i) + '_' + target_modality + '_raw_'+str(MRIDOWN)+'X'+str(CTNVIEW)+'P', + # kspace=kspace, under_t1=under_img, + # t1=source_slice, ct=target_slice) + + +def main(args): + clean(args.rootdir,args.savedir, args.source, args.target) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--rootdir', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020/') + parser.add_argument('--savedir', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/') + parser.add_argument('--source', default='t1') + parser.add_argument('--target', default='t2') + + main(parser.parse_args()) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/kspace_4_mask.npy b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/kspace_4_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..9ac77fa44d98099a5c07948465d1c0096de38828 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/kspace_4_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f68ba364235a51534884b434ac3a1c16d0cf263b9e4c08c5b3757214a6f78216 +size 2048128 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/kspace_8_mask.npy b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/kspace_8_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..f0217bbe6f62b18296c488807e2d8a90ac7f0118 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/kspace_8_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f7397d527311ac6ba09ee2621d2f964e276a16b4bf0aaded163653abef882bb +size 2048128 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/m4raw_4_mask.npy b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/m4raw_4_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..774c6d9fe7fd2df1a101f53b5bf32c2534fb20aa --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/example_mask/m4raw_4_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee522a2b8e4afa7a3349c2729effacabd9e4502be601bb176200892bded99e7f +size 6912128 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/fastmri.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/fastmri.py new file mode 100644 index 
0000000000000000000000000000000000000000..f5a65bf2deae9f9ad491f3b08e7dd02b23e7ffb7 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/fastmri.py @@ -0,0 +1,339 @@ +import csv +import os +import random +import xml.etree.ElementTree as etree +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import pathlib + +import h5py +import numpy as np +import torch +import yaml +from torch.utils.data import Dataset +from .transforms import build_transforms +from matplotlib import pyplot as plt + +from .albu_transform import get_albu_transforms + +def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")): + """ + Data directory fetcher. + + This is a brute-force simple way to configure data directories for a + project. Simply overwrite the variables for `knee_path` and `brain_path` + and this function will retrieve the requested subsplit of the data for use. + + Args: + key (str): key to retrieve path from data_config_file. + data_config_file (pathlib.Path, + default=pathlib.Path("fastmri_dirs.yaml")): Default path config + file. + + Returns: + pathlib.Path: The path to the specified directory. + """ + if not data_config_file.is_file(): + default_config = dict( + knee_path="/home/jc3/Data/", + brain_path="/home/jc3/Data/", + ) + with open(data_config_file, "w") as f: + yaml.dump(default_config, f) + + raise ValueError(f"Please populate {data_config_file} with directory paths.") + + with open(data_config_file, "r") as f: + data_dir = yaml.safe_load(f)[key] + + data_dir = pathlib.Path(data_dir) + + if not data_dir.exists(): + raise ValueError(f"Path {data_dir} from {data_config_file} does not exist.") + + return data_dir + + +def et_query( + root: etree.Element, + qlist: Sequence[str], + namespace: str = "http://www.ismrm.org/ISMRMRD", +) -> str: + """ + ElementTree query function. + This can be used to query an xml document via ElementTree. It uses qlist + for nested queries. + Args: + root: Root of the xml to search through. + qlist: A list of strings for nested searches, e.g. ["Encoding", + "matrixSize"] + namespace: Optional; xml namespace to prepend query. + Returns: + The retrieved data as a string. + """ + s = "." 
+ prefix = "ismrmrd_namespace" + + ns = {prefix: namespace} + + for el in qlist: + s = s + f"//{prefix}:{el}" + + value = root.find(s, ns) + if value is None: + raise RuntimeError("Element not found") + + return str(value.text) + + +class SliceDataset(Dataset): + def __init__( + self, + root, + transform, + challenge, + sample_rate=1, + mode='train' + ): + self.mode = mode + self.albu_transforms = get_albu_transforms(self.mode, (320, 320)) + + + # challenge + if challenge not in ("singlecoil", "multicoil"): + raise ValueError('challenge should be either "singlecoil" or "multicoil"') + self.recons_key = ( + "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss" + ) + # transform + self.transform = transform + + self.examples = [] + + self.cur_path = root + if not os.path.exists(self.cur_path): + self.cur_path = self.cur_path + "_selected" + + self.csv_file = "knee_data_split/singlecoil_" + self.mode + "_split_less.csv" + + with open(self.csv_file, 'r') as f: + reader = csv.reader(f) + + id = 0 + + for row in reader: + pd_metadata, pd_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[0] + '.h5')) + + pdfs_metadata, pdfs_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[1] + '.h5')) + + for slice_id in range(min(pd_num_slices, pdfs_num_slices)): + self.examples.append( + (os.path.join(self.cur_path, row[0] + '.h5'), os.path.join(self.cur_path, row[1] + '.h5') + , slice_id, pd_metadata, pdfs_metadata, id)) + id += 1 + + if sample_rate < 1: + random.shuffle(self.examples) + num_examples = round(len(self.examples) * sample_rate) + + self.examples = self.examples[0:num_examples] + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + + # read pd + pd_fname, pdfs_fname, slice, pd_metadata, pdfs_metadata, id = self.examples[i] + + with h5py.File(pd_fname, "r") as hf: + pd_kspace = hf["kspace"][slice] + + pd_mask = np.asarray(hf["mask"]) if "mask" in hf else None + + pd_target = hf[self.recons_key][slice] if self.recons_key in hf else None + + attrs = dict(hf.attrs) + + attrs.update(pd_metadata) + + if self.transform is None: + pd_sample = (pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + else: + pd_sample = self.transform(pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + + with h5py.File(pdfs_fname, "r") as hf: + pdfs_kspace = hf["kspace"][slice] + pdfs_mask = np.asarray(hf["mask"]) if "mask" in hf else None + + pdfs_target = hf[self.recons_key][slice] if self.recons_key in hf else None + + attrs = dict(hf.attrs) + + attrs.update(pdfs_metadata) + + if self.transform is None: + pdfs_sample = (pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + else: + pdfs_sample = self.transform(pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + + # 0: input, 1: target, 2: mean, 3: std + sample = self.albu_transforms(image=pdfs_sample[1].numpy(), + image2=pd_sample[1].numpy(), + image3=pdfs_sample[0].numpy(), + image4=pd_sample[0].numpy()) + + pdfs_sample = list(pdfs_sample) + pd_sample = list(pd_sample) + pdfs_sample[1] = sample['image'] + pd_sample[1] = sample['image2'] + pdfs_sample[0] = sample['image3'] + pd_sample[0] = sample['image4'] + + # dataset pdf mean and std tensor(3.1980e-05) tensor(1.3093e-05) + # print("dataset pdf mean and std", pdfs_sample[2], pdfs_sample[3]) + # print(pdfs_sample[1].shape, pdfs_sample[1].min(), pdfs_sample[1].max()) + + return (pd_sample, pdfs_sample, id) + + def _retrieve_metadata(self, fname): + with h5py.File(fname, "r") as hf: + et_root = 
etree.fromstring(hf["ismrmrd_header"][()]) + + enc = ["encoding", "encodedSpace", "matrixSize"] + enc_size = ( + int(et_query(et_root, enc + ["x"])), + int(et_query(et_root, enc + ["y"])), + int(et_query(et_root, enc + ["z"])), + ) + rec = ["encoding", "reconSpace", "matrixSize"] + recon_size = ( + int(et_query(et_root, rec + ["x"])), + int(et_query(et_root, rec + ["y"])), + int(et_query(et_root, rec + ["z"])), + ) + + lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"] + enc_limits_center = int(et_query(et_root, lims + ["center"])) + enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1 + + padding_left = enc_size[1] // 2 - enc_limits_center + padding_right = padding_left + enc_limits_max + + num_slices = hf["kspace"].shape[0] + + metadata = { + "padding_left": padding_left, + "padding_right": padding_right, + "encoding_size": enc_size, + "recon_size": recon_size, + } + + return metadata, num_slices + + +def build_dataset(args, mode='train', sample_rate=1, use_kspace=False): + assert mode in ['train', 'val', 'test'], 'unknown mode' + transforms = build_transforms(args, mode, use_kspace) + + return SliceDataset(os.path.join(args.root_path, 'singlecoil_' + mode), transforms, 'singlecoil', sample_rate=sample_rate, mode=mode) + + +if __name__ == "__main__": + ## make logger file + from torch.utils.data import DataLoader + from option import args + import time + from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_ksu_kernel, apply_tofre, \ + apply_to_spatial + + batch_size = 1 + db_train = build_dataset(args, mode='train') + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + + for i_batch, sampled_batch in enumerate(trainloader): + time2 = time.time() + # print("time for data loading:", time2 - time1) + + pd, pdfs, _ = sampled_batch + target = pdfs[1] + + mean = pdfs[2] + std = pdfs[3] + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + target = target.unsqueeze(1) + + b = pd_img.size(0) + + pd_img = pd_img # [4, 1, 320, 320] + pdfs_img = pdfs_img # [4, 1, 320, 320] + target = target # [4, 1, 320, 320] + + # ----------- Degradation ------------- + num_timesteps = 1 + image_size = 320 + + # Output a list of k-space kernels + kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0], + ) # args.ACCELERATIONS = [4] or [8] + kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + + + + t = torch.randint(0, num_timesteps, (b,)).long() + mask = kspace_masks[t] + fft, mask = apply_tofre(target.clone(), mask) + # fft = fft * mask + 0.0 + pdfs_img = apply_to_spatial(fft) + pdfs_img_mask = apply_to_spatial(mask * fft)[0] + + + + + print("mask = ", mask.shape, mask.min(), mask.max()) + print("pdfs_img_mask =", pdfs_img_mask.shape) + + import matplotlib.pyplot as plt + + # combine them together + pd_img = pd_img.squeeze(1).cpu().numpy() + pdfs_img = pdfs_img.squeeze(1).cpu().numpy() + target = target.squeeze(1).cpu().numpy() + + plt.figure() + + plt.subplot(161) + plt.imshow(pd_img[0], cmap='gray') + plt.title('PD') + plt.axis('off') + plt.subplot(162) + + plt.imshow(pdfs_img_mask[0], cmap='gray') + plt.title('PDFS_mask') + plt.axis('off') + + plt.subplot(163) + plt.imshow(pdfs_img[0], cmap='gray') + plt.title('PDFS') + plt.axis('off') + + plt.subplot(164) + plt.imshow(pdfs_img_mask[0] - target[0], cmap='gray') + plt.title('Diff') + plt.axis('off') + + plt.subplot(165) + plt.imshow(target[0], 
cmap='gray') + plt.title('Target') + plt.axis('off') + + plt.subplot(166) + plt.imshow(pdfs_img[0] - target[0], cmap='gray')#mask[0][0], cmap='gray') + plt.title('Target') + plt.axis('off') + + plt.show() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/hybrid_sparse.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/hybrid_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..42e4a7e33c2204c13a1c4509897baf19e1fb07f1 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/hybrid_sparse.py @@ -0,0 +1,156 @@ +from __future__ import print_function, division +import numpy as np +from glob import glob +import random +from skimage import transform + +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', transform=None): + + super().__init__() + self._base_dir = base_dir + self.im_ids = [] + self.images = [] + self.gts = [] + + if split=='train': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir+"/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + + elif split=='test': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir + "/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + self.transform = transform + + assert (len(self.images) == len(self.gts)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.images))) + + def __len__(self): + return len(self.images) + + + def __getitem__(self, index): + img_in, img, target_in, target= self._make_img_gt_point_pair(index) + sample = {'image_in': img_in, 'image':img, 'target_in': target_in, 'target': target} + # print("image in:", img_in.shape) + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + + # the default setting (i.e., rawdata.npz) is 4X64P + dd = np.load(self.images[index].replace('.png', '_raw_4X64P.npz')) + # print("images range:", dd['fbp'].max(), dd['ct'].max(), dd['under_t1'].max(), dd['t1'].max()) + _img_in = dd['fbp'] + _img_in[_img_in>0.6]=0.6 + _img_in = _img_in/0.6 + + _img = dd['ct'] + _img =(_img/1000*0.192+0.192) + _img[_img<0.0]=0.0 + _img[_img>0.6]=0.6 + _img = _img/0.6 + + _target_in = dd['under_t1'] + _target = dd['t1'] + + return _img_in, _img, _target_in, _target + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 400, 400 + crop_size = 384 + pad_size = (400-384)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class 
RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/kspace_subsample.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/kspace_subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..49da4fa5e508df325a98767e46725e93c9be0445 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/kspace_subsample.py @@ -0,0 +1,328 @@ +""" +2023/10/16, +preprocess kspace data with the undersampling mask in the fastMRI project. 
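# --- Editor's note: worked example, not part of the patch. --------------------
# In hybrid_sparse._make_img_gt_point_pair above, the CT slice is mapped from
# Hounsfield units to [0, 1] via mu = HU/1000 * 0.192 + 0.192 (0.192 is
# presumably the assumed linear attenuation of water), then clipped to
# [0, 0.6] and divided by 0.6:
for hu in (-1000, 0, 2000):                    # air, water, dense bone
    mu = hu / 1000 * 0.192 + 0.192             # -> 0.0, 0.192, 0.576
    print(hu, min(max(mu, 0.0), 0.6) / 0.6)    # -> 0.0, 0.32, 0.96
# -------------------------------------------------------------------------------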
+""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + masked_spectrum = spectrum * mask[None, :, :, None] + return spectrum, masked_spectrum + + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2), norm='ortho') + + return image + + +def add_gaussian_noise(kspace, snr): + ### 根据SNR确定noise的放大比例 + num_pixels = kspace.shape[0]*kspace.shape[1]*kspace.shape[2]*kspace.shape[3] + psr = torch.sum(torch.abs(kspace.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(kspace.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(kspace.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(kspace.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noisy_kspace = kspace + noise + + return noisy_kspace + + +def mri_fft(raw_mri, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + + if _SNR > 0: + noisy_kspace = add_gaussian_noise(kspace, _SNR) + else: + noisy_kspace = kspace + + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1) + + +from dataloaders.math import complex_abs, complex_abs_numpy, complex_abs_sq + +def mri_fft_m4raw(lq_mri, hq_mri): + # breakpoint() + lq_mri = torch.tensor(lq_mri[0])[None, :, :, None].to(torch.float32) + lq_mri_spectrum = torch.fft.fftn(lq_mri, dim=(1, 2), norm='ortho') + lq_mri_spectrum = torch.fft.fftshift(lq_mri_spectrum, dim=(1, 2)) + + # Complex + lq_mri = mri_inver_fourier_transform_2d(lq_mri_spectrum[0]) + # print("lq_mri shape:", lq_mri.shape) + lq_mri = torch.cat([torch.real(lq_mri), torch.imag(lq_mri)], dim=-1) + lq_mri = complex_abs(lq_mri) + lq_mri = torch.abs(lq_mri) + # print("lq_mri after shape:", lq_mri.shape) + lq_mri = lq_mri.unsqueeze(-1) + # + lq_kspace = torch.cat([torch.real(lq_mri_spectrum), torch.imag(lq_mri_spectrum)], dim=-1) + lq_kspace = torch.abs(complex_abs(lq_kspace[0])) + lq_kspace = lq_kspace.unsqueeze(-1) + + hq_mri = torch.tensor(hq_mri[0])[None, :, :, None].to(torch.float32) + hq_mri_spectrum = torch.fft.fftn(hq_mri, dim=(1, 2), norm='ortho') + hq_mri_spectrum = torch.fft.fftshift(hq_mri_spectrum, dim=(1, 2)) + + 
hq_mri = mri_inver_fourier_transform_2d(hq_mri_spectrum[0]) + hq_mri = torch.cat([torch.real(hq_mri), torch.imag(hq_mri)], dim=-1) + + hq_mri = complex_abs(hq_mri) # Convert the complex number to the absolute value. + hq_mri = torch.abs(hq_mri) + hq_mri = hq_mri.unsqueeze(-1) + # + hq_kspace = torch.cat([torch.real(hq_mri_spectrum), torch.imag(hq_mri_spectrum)], dim=-1) + + hq_kspace = torch.abs(complex_abs(hq_kspace[0])) + hq_kspace = hq_kspace.unsqueeze(-1) + + # breakpoint() + return lq_kspace, lq_mri.permute(2, 0, 1), \ + hq_kspace, hq_mri.permute(2, 0, 1) + + +def undersample_mri(raw_mri, _MRIDOWN, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] # [1, 240] + # print("mask:", mask.shape) + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + ### under-sample the kspace data. + kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) + ### add low-field noise to the kspace data. + if _SNR > 0: + noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) + else: + noisy_kspace = masked_kspace + + ### conver the corrupted kspace data back to noisy MRI image. + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). 
This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
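# --- Editor's note: worked example, not part of the patch. --------------------
# Numbers for the equispaced scheme described above (the adjusted spacing is
# computed in the method body that follows), assuming the 240-column masks
# built by undersample_mri above with center_fraction = 0.04 and
# acceleration = 8: 10 centre columns are always kept, the outer columns are
# sampled every ~11.5 columns, so roughly 10 + 240/11.5 ~= 31 columns survive,
# close to 240 / 8 = 30.
num_cols, center_fraction, acceleration = 240, 0.04, 8
num_low_freqs = int(round(num_cols * center_fraction))          # 10
adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (
    num_low_freqs * acceleration - num_cols
)                                                               # 11.5
approx_total = num_low_freqs + num_cols / adjusted_accel        # ~30.9
# -------------------------------------------------------------------------------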
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/m4_utils.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/m4_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0d76246b0c8fb757b6b589b7f889e0316696d0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/m4_utils.py @@ -0,0 +1,272 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. + + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + +# def fft2c(data): +# """ +# Apply centered 2 dimensional Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The FFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# data = torch.fft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + +# def ifft2c(data): +# """ +# Apply centered 2-dimensional Inverse Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. 
All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The IFFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# # data = torch.ifft(data, 2, normalized=True) +# data = torch.fft.ifft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. 
+ """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/m4raw_dataloader.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/m4raw_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..6f676b25df3a18129bea4a7bf102d7844bca332c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/m4raw_dataloader.py @@ -0,0 +1,574 @@ + +from __future__ import print_function, division +from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union +import sys +sys.path.append('.') +from glob import glob +import os, time +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import numpy as np +import torch +from torch.utils.data import Dataset + +import h5py +from matplotlib import pyplot as plt +from dataloaders.math import ifft2c, fft2c, complex_abs +from dataloaders.kspace_subsample import create_mask_for_mask_type + +import argparse +from torch.utils.data import DataLoader +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from dataloaders.kspace_subsample import undersample_mri, mri_fft, mri_fft_m4raw +from tqdm import tqdm + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. 
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +def normal(x): + y = np.zeros_like(x) + for i in range(y.shape[0]): + x_min = x[i].min() + x_max = x[i].max() + y[i] = (x[i] - x_min)/(x_max-x_min) + return y + + + +def undersample_mri(kspace, _MRIDOWN): + # print("kspace shape:", kspace.shape) ## [18, 4, 256, 256, 2] + + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [256, 256, 1] + mask = ff(shape, seed=1337) ## [1, 256, 1] + mask = mask[:, :, 0] # [1, 256] + + masked_kspace = kspace * mask[None, None, :, :, None] + + return masked_kspace, mask.unsqueeze(-1) + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + +def read_h5(file_name, _MRIDOWN='None', use_kspace=False): + hf = h5py.File(file_name) + volume_kspace = hf['kspace'][()] + slice_kspace = volume_kspace + slice_kspace2 = to_tensor(slice_kspace) + + slice_image = ifft2c(slice_kspace2) + slice_image_abs = complex_abs(slice_image) + slice_image_rss = rss(slice_image_abs, dim=1) + slice_image_rss = np.abs(slice_image_rss.numpy()) + slice_image_rss = normal(slice_image_rss) + + if _MRIDOWN == 'None' or use_kspace: + masked_image_rss = slice_image_rss + + else: + # print("Undersample MRI") + # Undersample MRI + masked_kspace, mask = undersample_mri(slice_kspace2, _MRIDOWN) # Masked + + masked_image = ifft2c(masked_kspace) + masked_image_abs = complex_abs(masked_image) + masked_image_rss = rss(masked_image_abs, dim=1) + masked_image_rss = np.abs(masked_image_rss.numpy()) + masked_image_rss = normal(masked_image_rss) + + return slice_image_rss, masked_image_rss + + +DEBUG = True + +class M4Raw_TrainSet(Dataset): + def __init__(self, root_path, MRIDOWN, kspace_refine='False', use_kspace=False): + + self.use_kspace = use_kspace + self.kspace_refine = kspace_refine + start_time = time.time() + + input_list1 = sorted(glob(os.path.join(root_path + '/multicoil_train' + '/*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(root_path + '/multicoil_train' +'/*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = 
[path.replace('_T202.h5','_T203.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T2_input_list = [input_list1, input_list2, input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 256, 256]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 256, 256]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 256, 256]) + + """ + 读取kspace network重建的图像 + """ + if kspace_refine == 'True': + krecon_list1 = sorted(glob(os.path.join(root_path + 'multicoil_train' + '/*_T102_recon_kspace_round2_images.npy'))) + krecon_list2 = [path.replace('_T102','_T101') for path in krecon_list1] + krecon_list3 = [path.replace('_T102','_T103') for path in krecon_list1] + T1_krecon_list = [krecon_list1, krecon_list2, krecon_list3] + + krecon_list1 = sorted(glob(os.path.join(root_path + 'multicoil_train' + '/*_T202_recon_kspace_round2_images.npy'))) + krecon_list2 = [path.replace('_T202','_T201') for path in krecon_list1] + krecon_list3 = [path.replace('_T202','_T203') for path in krecon_list1] + T2_krecon_list = [krecon_list1, krecon_list2, krecon_list3] + + self.T1_krecon_list = T1_krecon_list + self.T2_krecon_list = T2_krecon_list + + self.T1_krecon = np.zeros([len(input_list1), len(T1_krecon_list), 18, 240, 240]).astype(np.float32) + self.T2_krecon = np.zeros([len(input_list2), len(T2_krecon_list), 18, 240, 240]).astype(np.float32) + + + + print('TrainSet loading...') + for i in tqdm(range(len(self.T1_input_list))): + for j, path in enumerate(T1_input_list[i]): + self.T1_images[j][i], _ = read_h5(path, use_kspace=use_kspace) + # self.fname_slices[i].append(path) # each coil + + if kspace_refine == 'True': + for k, path in enumerate(T1_krecon_list[i]): + self.T1_krecon[k][i] = np.load(path).astype(np.float32)/255.0 + self.T1_labels = np.mean(self.T1_images, axis=1) # multi-coil mean + + for i in tqdm(range(len(self.T2_input_list))): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i] = read_h5(path, _MRIDOWN=MRIDOWN, use_kspace=use_kspace) + if kspace_refine == 'True': + for k, path in enumerate(T2_krecon_list[i]): + self.T2_krecon[k][i] = np.load(path).astype(np.float32)/255.0 + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print(f'Finish loading with time = {time.time() - start_time}s') + + # print("T1 image original shape:", self.T1_images.shape) # T1 image original shape: (128, 3, 18, 256, 256) + # print("T2 image original shape:", self.T2_images.shape) + + N, _, S, H, W = self.T1_images.shape + self.fname_slices = [] + + for i in range(N): + for j in range(S): + self.fname_slices.append((i, j)) + # print(f'nan value at {i}, {j}, {k}, {l}') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),256,256)[:, :, 8:248, 8:248] + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),256,256)[:, :, 8:248, 8:248] + self.T1_labels = self.T1_labels.reshape(-1,1,256,256)[:, :, 8:248, 8:248] + self.T2_labels = self.T2_labels.reshape(-1,1,256,256)[:, :, 8:248, 8:248] + + # Train data shape: (2304, 3, 240, 240) + + # T1 N, 3, 240, 240 + + if kspace_refine == 'True': + self.T1_krecon = self.T1_krecon.transpose(0,2,1,3,4).reshape(-1,len(T1_krecon_list),240,240) + self.T2_krecon = 
self.T2_krecon.transpose(0,2,1,3,4).reshape(-1,len(T2_krecon_list),240,240)
+
+
+    def __len__(self):
+        return len(self.T1_images)
+
+    def __getitem__(self, idx):
+        T1_images = self.T1_images[idx]  # lq_mri
+        T2_images = self.T2_images[idx]
+        T1_labels = self.T1_labels[idx]  # gt_mri
+        T2_labels = self.T2_labels[idx]
+
+        fname = self.fname_slices[idx][0]
+        slice = self.fname_slices[idx][1]
+
+        ## Randomly pick one of the three repetitions as the input each time.
+        choices = np.random.choice([i for i in range(len(self.T1_input_list))],1)
+        T1_images = T1_images[choices]
+        T2_images = T2_images[choices]
+
+        t1_kspace_in, t1_in, t1_kspace, t1_img = mri_fft_m4raw(T1_images, T1_labels)
+        t2_kspace_in, t2_in, t2_kspace, t2_img = mri_fft_m4raw(T2_images, T2_labels)
+
+
+        # normalize
+        t1_img, t1_mean, t1_std = normalize_instance(t1_img)
+        t1_in = normalize(t1_in, t1_mean, t1_std)
+        # t1_mean = 0
+        # t1_std = 1
+
+        t2_img, t2_mean, t2_std = normalize_instance(t2_img)
+        t2_in = normalize(t2_in, t2_mean, t2_std)
+
+
+
+        # clamp values that are greater than 6 or less than -6
+        t1_img = torch.clamp(t1_img, -6, 6)
+        t2_img = torch.clamp(t2_img, -6, 6)
+        t1_in = torch.clamp(t1_in, -6, 6)
+        t2_in = torch.clamp(t2_in, -6, 6)
+
+
+        # t2_mean = 0
+        # t2_std = 1
+
+        # t1_img: torch.Size([2, 240, 240]) torch.float32 tensor(0.9775) tensor(-9.0143e-08)
+        # t2_img: torch.Size([1, 240, 240]) torch.float32 tensor(22.1929) tensor(-0.3244)
+
+        # t1_img: torch.Size([1, 240, 240]) torch.float32 tensor(5.1340) tensor(1.7756e-06)
+        # t2_img: torch.Size([1, 240, 240]) torch.float32 tensor(4.4957) tensor(2.8719e-05)
+        # t1_in: torch.Size([1, 240, 240]) torch.float32 tensor(5.2390) tensor(0.0003)
+        # t2_in: torch.Size([1, 240, 240]) torch.float32 tensor(4.7321) tensor(4.5622e-05)
+
+        # print("t1_img:", t1_img.shape, t1_img.dtype, t1_img.max(), t1_img.min())
+        # print("t2_img:", t2_img.shape, t2_img.dtype, t2_img.max(), t2_img.min())
+        # print("t1_in:", t1_in.shape, t1_in.dtype, t1_in.max(), t1_in.min())
+        # print("t2_in:", t2_in.shape, t2_in.dtype, t2_in.max(), t2_in.min())  # t1_img: torch.Size([1, 240, 240]) torch.float32 tensor(20.5561) tensor(-0.2671)
+        # print()
+
+        # How to get mean and std of the training data?
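+        # One possible answer (illustrative sketch only, not what this loader
+        # currently does): compute global statistics once, offline, from the
+        # label volumes the dataset already holds, e.g.
+        #
+        #     t2_mean_global = self.T2_labels.mean()
+        #     t2_std_global  = self.T2_labels.std()
+        #
+        # and then pass those fixed values to normalize() instead of the
+        # per-slice normalize_instance() used above. The hypothetical
+        # t2_mean_global / t2_std_global names are for illustration only.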
+ # fname, slice + sample = { + 'fname': fname, + 'slice': slice, + + 'ref_kspace_full': t1_kspace, + 'ref_kspace_sub': t1_kspace_in, + 'ref_image_full': t1_img, + 'ref_image_sub': t1_in, + 't1_mean': t1_mean, + 't1_std': t1_std, + + 'tag_kspace_full': t2_kspace, + 'tag_kspace_sub': t2_kspace_in, + 'tag_image_full': t2_img, + 'tag_image_sub': t2_in, + 't2_mean': t2_mean, + 't2_std': t2_std, + + } + + return sample + + + +class M4Raw_TestSet(Dataset): + def __init__(self, root_path, MRIDOWN, kspace_refine='False', use_kspace=False): + + self.kspace_refine = kspace_refine + + input_list1 = sorted(glob(os.path.join(root_path + '/multicoil_val' + '/*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(root_path + '/multicoil_val' + '/*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T2_input_list = [input_list1,input_list2,input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 256, 256]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 256, 256]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 256, 256]) + + """ + 读取kspace network重建的图像 + """ + if kspace_refine == 'True': + krecon_list1 = sorted(glob(os.path.join(root_path + 'multicoil_val' + '/*_T102_recon_kspace_round2_images.npy'))) + krecon_list2 = [path.replace('_T102','_T101') for path in krecon_list1] + krecon_list3 = [path.replace('_T102','_T103') for path in krecon_list1] + T1_krecon_list = [krecon_list1, krecon_list2, krecon_list3] + + krecon_list1 = sorted(glob(os.path.join(root_path + 'multicoil_val' + '/*_T202_recon_kspace_round2_images.npy'))) + krecon_list2 = [path.replace('_T202','_T201') for path in krecon_list1] + krecon_list3 = [path.replace('_T202','_T203') for path in krecon_list1] + T2_krecon_list = [krecon_list1, krecon_list2, krecon_list3] + + self.T1_krecon_list = T1_krecon_list + self.T2_krecon_list = T2_krecon_list + + self.T1_krecon = np.zeros([len(input_list1), len(T1_krecon_list), 18, 240, 240]).astype(np.float32) + self.T2_krecon = np.zeros([len(input_list2), len(T2_krecon_list), 18, 240, 240]).astype(np.float32) + + + print('TestSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + self.T1_images[j][i], _ = read_h5(path, use_kspace=use_kspace) + + if kspace_refine == 'True': + for k, path in enumerate(T1_krecon_list[i]): + self.T1_krecon[k][i] = np.load(path).astype(np.float32)/255.0 + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i] = read_h5(path, _MRIDOWN = MRIDOWN, use_kspace=use_kspace) + + if kspace_refine == 'True': + for k, path in enumerate(T2_krecon_list[i]): + self.T2_krecon[k][i] = np.load(path).astype(np.float32)/255.0 + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = 
self.T2_masked_images + + print('Finish loading') + N, _, S, H, W = self.T1_images.shape + self.fname_slices = [] + + for i in range(N): + for j in range(S): + self.fname_slices.append((i, j)) + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),256,256)[:, :, 8:248, 8:248] + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),256,256)[:, :, 8:248, 8:248] + self.T1_labels = self.T1_labels.reshape(-1,1,256,256)[:, :, 8:248, 8:248] + self.T2_labels = self.T2_labels.reshape(-1,1,256,256)[:, :, 8:248, 8:248] + print("Test data shape:", self.T1_images.shape) + + if kspace_refine == 'True': + self.T1_krecon = self.T1_krecon.transpose(0,2,1,3,4).reshape(-1,len(T1_krecon_list),240,240) + self.T2_krecon = self.T2_krecon.transpose(0,2,1,3,4).reshape(-1,len(T2_krecon_list),240,240) + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + + # print("T1_labels:", T1_labels.shape, T1_labels.dtype, T1_labels.max(), T1_labels.min()) + # print("T2_labels:", T2_labels.shape, T2_labels.dtype, T2_labels.max(), T2_labels.min()) + + choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + + t1_kspace_in, t1_in, t1_kspace, t1_img = mri_fft_m4raw(T1_images, T1_labels) + t2_kspace_in, t2_in, t2_kspace, t2_img = mri_fft_m4raw(T2_images, T2_labels) + + fname = self.fname_slices[idx][0] + slice = self.fname_slices[idx][1] + + # normalize + t1_img, t1_mean, t1_std = normalize_instance(t1_img) + t1_in = normalize(t1_in, t1_mean, t1_std) + # t1_mean = 0 + # t1_std = 1 + + t2_img, t2_mean, t2_std = normalize_instance(t2_img) + t2_in = normalize(t2_in, t2_mean, t2_std) + + # filter value that greater or less than 6 + t1_img = torch.clamp(t1_img, -6, 6) + t2_img = torch.clamp(t2_img, -6, 6) + t1_in = torch.clamp(t1_in, -6, 6) + t2_in = torch.clamp(t2_in, -6, 6) + + + # print("t1_img:", t1_img.shape, t1_img.dtype, t1_img.max(), t1_img.min()) + # print("in dataset t2_img:", t2_img.shape, t2_img.dtype, t2_img.max(), t2_img.min()) + # print("t1_in:", t1_in.shape, t1_in.dtype, t1_in.max(), t1_in.min()) + # print("t2_in:", t2_in.shape, t2_in.dtype, t2_in.max(), t2_in.min()) # t1_img: torch.Size([1, 240, 240]) torch.float32 tensor(20.5561) tensor(-0.2671) + # print() + + # fname, slice + sample = { + 'fname': fname, + 'slice': slice, + + 'ref_kspace_full': t1_kspace, + 'ref_kspace_sub': t1_kspace_in, + 'ref_image_full': t1_img, + 'ref_image_sub': t1_in, + 't1_mean': t1_mean, + 't1_std': t1_std, + + 'tag_kspace_full': t2_kspace, + 'tag_kspace_sub': t2_kspace_in, + 'tag_image_full': t2_img, + 'tag_image_sub': t2_in, + 't2_mean': t2_mean, + 't2_std': t2_std, + + } + + return sample + + +def compute_metrics(image, labels): + MSE = mean_squared_error(labels, image)/np.var(labels) + PSNR = peak_signal_noise_ratio(labels, image) + SSIM = structural_similarity(labels, image) + + # print("metrics:", MSE, PSNR, SSIM) + + return MSE, PSNR, SSIM + +def complex_abs_eval(data): + return (data[0:1, :, :] ** 2 + data[1:2, :, :] ** 2).sqrt() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--root_path', type=str, default='/data/qic99/MRI_recon/') + parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') + parser.add_argument('--kspace_refine', type=str, default='False', 
help='whether use the image reconstructed from kspace network.') + args = parser.parse_args() + + db_test = M4Raw_TestSet(args) + testloader = DataLoader(db_test, batch_size=4, shuffle=False, num_workers=4, pin_memory=True) + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + save_dir = "./visualize_images/" + + for i_batch, sampled_batch in enumerate(testloader): + # t1_in, t2_in = sampled_batch['t1_in'].cuda(), sampled_batch['t2_in'].cuda() + # t1, t2 = sampled_batch['t1_labels'].cuda(), sampled_batch['t2_labels'].cuda() + + t1_in, t2_in = sampled_batch['ref_image_sub'].cuda(), sampled_batch['tag_image_sub'].cuda() + t1, t2 = sampled_batch['ref_image_full'].cuda(), sampled_batch['tag_image_full'].cuda() + + # breakpoint() + for j in range(t1_in.shape[0]): + # t1_in_img = (np.clip(complex_abs_eval(t1_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(complex_abs_eval(t1[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(complex_abs_eval(t2_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(complex_abs_eval(t2[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # breakpoint() + t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + # t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + + # print(t1_in_img.shape, t1_img.shape) + t1_MSE, t1_PSNR, t1_SSIM = compute_metrics(t1_in_img, t1_img) + t2_MSE, t2_PSNR, t2_SSIM = compute_metrics(t2_in_img, t2_img) + + + t1_MSE_all.append(t1_MSE) + t1_PSNR_all.append(t1_PSNR) + t1_SSIM_all.append(t1_SSIM) + + t2_MSE_all.append(t2_MSE) + t2_PSNR_all.append(t2_PSNR) + t2_SSIM_all.append(t2_SSIM) + + + # print("t1_PSNR:", t1_PSNR_all) + print("t1_PSNR:", round(np.array(t1_PSNR_all).mean(), 4)) + print("t1_NMSE:", round(np.array(t1_MSE_all).mean(), 4)) + print("t1_SSIM:", round(np.array(t1_SSIM_all).mean(), 4)) + + print("t2_PSNR:", round(np.array(t2_PSNR_all).mean(), 4)) + print("t2_NMSE:", round(np.array(t2_MSE_all).mean(), 4)) + print("t2_SSIM:", round(np.array(t2_SSIM_all).mean(), 4)) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/m4raw_std_dataloader.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/m4raw_std_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..536d401a4ae22bc11b4cc3f40fa697ba8699295a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/m4raw_std_dataloader.py @@ -0,0 +1,583 @@ + +from __future__ import print_function, division +from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union +import sys +sys.path.append('.') +from glob import glob +import os +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import numpy as np +import torch +from torch.utils.data import Dataset + +import h5py +from matplotlib import pyplot as plt +from dataloaders.m4_utils import ifft2c, fft2c, complex_abs +from dataloaders.kspace_subsample import create_mask_for_mask_type + +from .albu_transform import get_albu_transforms + +import 
argparse, time
+from torch.utils.data import DataLoader
+from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
+
+def normalize(data, mean, stddev, eps=0.0):
+    """
+    Normalize the given tensor.
+
+    Applies the formula (data - mean) / (stddev + eps).
+
+    Args:
+        data (torch.Tensor): Input data to be normalized.
+        mean (float): Mean value.
+        stddev (float): Standard deviation.
+        eps (float, default=0.0): Added to stddev to prevent dividing by zero.
+
+    Returns:
+        torch.Tensor: Normalized tensor
+    """
+    return (data - mean) / (stddev + eps)
+
+def normalize_instance_dim(data, eps=0.0):
+    """
+    Normalize the given tensor with instance norm.
+
+    Applies the formula (data - mean) / (stddev + eps), where mean and stddev
+    are computed from the data itself.
+
+    Args:
+        data (torch.Tensor): Input data to be normalized
+        eps (float): Added to stddev to prevent dividing by zero
+
+    Returns:
+        torch.Tensor: Normalized tensor
+    """
+    mean = data.mean(dim=(1, 2, 3), keepdim=True)  # B, C, H, W
+    std = data.std(dim=(1, 2, 3), keepdim=True)
+
+    return normalize(data, mean, std, eps), mean, std
+
+
+def normalize_instance(data, eps=0.0):
+    """
+    Normalize the given tensor with instance norm.
+
+    Applies the formula (data - mean) / (stddev + eps), where mean and stddev
+    are computed from the data itself.
+
+    Args:
+        data (torch.Tensor): Input data to be normalized
+        eps (float): Added to stddev to prevent dividing by zero
+
+    Returns:
+        torch.Tensor: Normalized tensor
+    """
+    mean = data.mean()
+    std = data.std()
+
+    return normalize(data, mean, std, eps), mean, std
+
+def undersample_mri(kspace, _MRIDOWN):
+    # print("kspace shape:", kspace.shape)  ## [18, 4, 256, 256, 2]
+    if _MRIDOWN == "4X":
+        mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4
+    elif _MRIDOWN == "8X":
+        mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8
+    elif _MRIDOWN == "12X":
+        mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.03, 12
+
+    ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN])  ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8
+
+    shape = [256, 256, 1]
+    mask = ff(shape, seed=1337)  ## [1, 256, 1]
+
+    mask = mask[:, :, 0]  # [1, 256]
+
+    masked_kspace = kspace * mask[None, None, :, :, None]
+
+    return masked_kspace, mask.unsqueeze(-1)
+
+
+def apply_mask(data, mask_func, seed=None, padding=None):
+    """
+    Subsample given k-space by multiplying with a mask.
+
+    Args:
+        data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where
+            dimensions -3 and -2 are the spatial dimensions, and the final dimension has size
+            2 (for complex values).
+        mask_func (callable): A function that takes a shape (tuple of ints) and a random
+            number seed and returns a mask.
+        seed (int or 1-d array_like, optional): Seed for the random number generator.
+ + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. 
+ """ + return torch.sqrt((data ** 2).sum(dim)) + +def read_h5(file_name, _MRIDOWN, use_kspace): + crop_size=[240,240] + + hf = h5py.File(file_name) + volume_kspace = hf['kspace'][()] + + slice_kspace = volume_kspace + slice_kspace = to_tensor(slice_kspace) + import imageio as io + + target = ifft2c(slice_kspace) + target = complex_center_crop(target, crop_size) + target = complex_abs(target) + target = rss(target, dim=1) + + if not use_kspace: + # print("use_kspace is False") + + # masked_kspace, mask = apply_mask(slice_kspace, mask_func, seed=123456) + masked_kspace, mask = undersample_mri(slice_kspace, _MRIDOWN) + lq_image = ifft2c(masked_kspace) + lq_image = complex_center_crop(lq_image, crop_size) + lq_image = complex_abs(lq_image) + lq_image = rss(lq_image, dim=1) + + else: + lq_image = target + + lq_image_list=[] + mean_list=[] + std_list=[] + for i in range(lq_image.shape[0]): + image, mean, std = normalize_instance(lq_image[i], eps=1e-11) + image = image.clamp(-6, 6) + lq_image_list.append(image) + mean_list.append(mean) + std_list.append(std) + + + + target_list=[] + + for i in range(lq_image.shape[0]): + target_slice = normalize(target[i], mean_list[i], std_list[i], eps=1e-11) + target_slice = target_slice.clamp(-6, 6) + target_list.append(target_slice) + + return torch.stack(lq_image_list), torch.stack(target_list), torch.stack(mean_list), torch.stack(std_list) + + + + + +class M4Raw_TrainSet(Dataset): + def __init__(self, args, use_kspace=False, DEBUG=False): + mask_func = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + start_time = time.time() + + self.albu_transforms = get_albu_transforms("train", (240, 240)) + + self._MRIDOWN = args.MRIDOWN + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + + T2_input_list = [input_list1, input_list2, input_list3] + + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + + print('TrainSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN, use_kspace=use_kspace) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_masked_images[j][i], self.T2_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, self._MRIDOWN, 
use_kspace=use_kspace)
+                # lq_image_list, target_list
+
+
+        # self.T2_labels = np.mean(self.T2_images, axis=1) # TODO
+        self.T2_labels = np.mean(self.T2_images, axis=1)
+        self.T2_images = self.T2_masked_images
+
+        print(f'Finish loading with time = {time.time() - start_time}s')
+
+        # print("T1 image original shape:", self.T1_images.shape)  # T1 image original shape: (128, 3, 18, 256, 256)
+        # print("T2 image original shape:", self.T2_images.shape)
+
+        N, _, S, H, W = self.T1_images.shape
+        self.fname_slices = []
+
+        for i in range(N):
+            for j in range(S):
+                self.fname_slices.append((i, j))
+
+
+        self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240)
+        self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240)
+        self.T1_labels = self.T1_labels.reshape(-1,1,240,240)
+        self.T2_labels = self.T2_labels.reshape(-1,1,240,240)
+        self.T2_mean = self.T2_mean.reshape(-1,3)
+        self.T2_std = self.T2_std.reshape(-1,3)
+
+        print("Train data shape:", self.T1_images.shape)
+        # breakpoint()
+
+
+    def __len__(self):
+        return len(self.T1_images)
+
+    def __getitem__(self, idx):
+        T1_images = self.T1_images[idx]
+        T2_images = self.T2_images[idx]
+        T1_labels = self.T1_labels[idx]
+        T2_labels = self.T2_labels[idx]
+
+        fname = self.fname_slices[idx][0]
+        slice = self.fname_slices[idx][1]
+
+
+        choices = np.random.choice([i for i in range(len(self.T1_input_list))],1)  ## Randomly pick one of the three repetitions as the input each time.
+        # choices = np.random.choice([0],1)  ## use the first repetition as the input image for testing
+        T1_images = T1_images[choices]
+        T2_images = T2_images[choices]
+        t2_mean = self.T2_mean[idx][choices]
+        t2_std = self.T2_std[idx][choices]
+
+        # (1, 240, 240)
+        sample = self.albu_transforms(image=T1_images[0], image2=T2_images[0],
+                                      image3=T1_labels[0], image4=T2_labels[0])
+
+        # breakpoint()
+        t1_in = np.expand_dims(sample['image'], 0)
+        t2_in = np.expand_dims(sample['image2'], 0)
+        t1 = np.expand_dims(sample['image3'], 0)
+        t2 = np.expand_dims(sample['image4'], 0)
+
+        # sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std}
+
+        # print("t1_in shape:", t1_in.shape, "t1 shape:", t1.shape, "t2_in shape:", t2_in.shape, "t2 shape:", t2.shape)
+
+        # breakpoint()
+        sample = {
+            'fname': fname,
+            'slice': slice,
+
+            't1_in': t1_in.astype(np.float32),
+            't1': t1.astype(np.float32),
+            "t2_mean": t2_mean, "t2_std": t2_std,
+
+            't2_in': t2_in.astype(np.float32),
+            't2': t2.astype(np.float32)}
+
+        return sample  #, sample_stats
+
+
+
+class M4Raw_TestSet(Dataset):
+    def __init__(self, args, use_kspace=False, DEBUG=False):
+        mask_func = create_mask_for_mask_type(
+            args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS,
+        )
+        self.use_kspace = use_kspace
+        self._MRIDOWN = args.MRIDOWN
+
+        self.input_normalize = args.input_normalize
+        input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val', '*_T102.h5')))
+        input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1]
+        input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1]
+        if DEBUG:
+            input_list1 = input_list1[:2]
+            input_list2 = input_list2[:2]
+            input_list3 = input_list3[:2]
+
+
+        T1_input_list = [input_list1, input_list2, input_list3]
+
+        input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val', '*_T202.h5')))
+        input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1]
+        input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1]
+        if DEBUG:
+            input_list1 = input_list1[:2]
+            input_list2 = input_list2[:2]
+            input_list3 = input_list3[:2]
+
+
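+        # Note (descriptive comment): as in the training set above, the file
+        # lists are grouped as [repetition][subject], so the image arrays built
+        # below have shape (num_subjects, 3 repetitions, 18 slices, 240, 240).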
T2_input_list = [input_list1,input_list2,input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + print('TestSet loading...') + + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN, use_kspace=use_kspace) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_masked_images[j][i], self.T2_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, + self._MRIDOWN, + use_kspace=use_kspace) + # lq_image_list, target_list + + # self.T2_labels = np.mean(self.T2_images, axis=1) # TODO + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + N, _, S, H, W = self.T1_images.shape + self.fname_slices = [] + + for i in range(N): + for j in range(S): + self.fname_slices.append((i, j)) + + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + print("Train data shape:", self.T1_images.shape) + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + + + fname = self.fname_slices[idx][0] + slice = self.fname_slices[idx][1] + + + + # choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. 
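+        # For testing, always take the first repetition so evaluation is
+        # deterministic; np.random.choice([0], 1) below always returns index 0.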
+ choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + + + T1_images = T1_images[choices] + T2_images = T2_images[choices] + + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + + + # breakpoint() + # import imageio as io + # io.imsave('T1_images.png', (T1_images[0]*255).astype(np.uint8)) + # io.imsave('T1_labels.png', (T1_labels[0]*255).astype(np.uint8)) + # io.imsave('T2_images.png', (T2_images[0]*255).astype(np.uint8)) + # io.imsave('T2_labels.png', (T2_labels[0]*255).astype(np.uint8)) + # breakpoint() + + t1_in = T1_images + t1 = T1_labels + t2_in = T2_images + t2 = T2_labels + + # sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + # print("Test t1_in shape:", t1_in.shape, "t1 shape:", t1.shape, "t2_in shape:", t2_in.shape, "t2 shape:", t2.shape) + + # breakpoint() + sample = { + 'fname': fname, + 'slice': slice, + + 't1_in': t1_in.astype(np.float32), + 't1': t1.astype(np.float32), + "t2_mean": t2_mean, "t2_std": t2_std, + + 't2_in': t2_in.astype(np.float32), + 't2': t2.astype(np.float32)} + + return sample #, sample_stats + + + +def compute_metrics(image, labels): + MSE = mean_squared_error(labels, image)/np.var(labels) + PSNR = peak_signal_noise_ratio(labels, image) + SSIM = structural_similarity(labels, image) + + # print("metrics:", MSE, PSNR, SSIM) + + return MSE, PSNR, SSIM + +def complex_abs_eval(data): + return (data[0:1, :, :] ** 2 + data[1:2, :, :] ** 2).sqrt() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--root_path', type=str, default='/data/qic99/MRI_recon/') + parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') + parser.add_argument('--kspace_refine', type=str, default='False', help='whether use the image reconstructed from kspace network.') + args = parser.parse_args() + + db_test = M4Raw_TestSet(args) + testloader = DataLoader(db_test, batch_size=4, shuffle=False, num_workers=4, pin_memory=True) + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + save_dir = "./visualize_images/" + + for i_batch, sampled_batch in enumerate(testloader): + # t1_in, t2_in = sampled_batch['t1_in'].cuda(), sampled_batch['t2_in'].cuda() + # t1, t2 = sampled_batch['t1_labels'].cuda(), sampled_batch['t2_labels'].cuda() + + t1_in, t2_in = sampled_batch['ref_image_sub'].cuda(), sampled_batch['tag_image_sub'].cuda() + t1, t2 = sampled_batch['ref_image_full'].cuda(), sampled_batch['tag_image_full'].cuda() + + # breakpoint() + for j in range(t1_in.shape[0]): + # t1_in_img = (np.clip(complex_abs_eval(t1_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(complex_abs_eval(t1[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(complex_abs_eval(t2_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(complex_abs_eval(t2[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # breakpoint() + t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + # t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = 
(np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + + # print(t1_in_img.shape, t1_img.shape) + t1_MSE, t1_PSNR, t1_SSIM = compute_metrics(t1_in_img, t1_img) + t2_MSE, t2_PSNR, t2_SSIM = compute_metrics(t2_in_img, t2_img) + + + t1_MSE_all.append(t1_MSE) + t1_PSNR_all.append(t1_PSNR) + t1_SSIM_all.append(t1_SSIM) + + t2_MSE_all.append(t2_MSE) + t2_PSNR_all.append(t2_PSNR) + t2_SSIM_all.append(t2_SSIM) + + + # print("t1_PSNR:", t1_PSNR_all) + print("t1_PSNR:", round(np.array(t1_PSNR_all).mean(), 4)) + print("t1_NMSE:", round(np.array(t1_MSE_all).mean(), 4)) + print("t1_SSIM:", round(np.array(t1_SSIM_all).mean(), 4)) + + print("t2_PSNR:", round(np.array(t2_PSNR_all).mean(), 4)) + print("t2_NMSE:", round(np.array(t2_MSE_all).mean(), 4)) + print("t2_SSIM:", round(np.array(t2_SSIM_all).mean(), 4)) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/math.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/math.py new file mode 100644 index 0000000000000000000000000000000000000000..120b9f0501b1ef187228e2650413d675b307d1cb --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/math.py @@ -0,0 +1,231 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. + + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. 
+ + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. 
+ + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/subsample.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..d0620da3414c6077e4293376fb8a9be01ad19990 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/subsample.py @@ -0,0 +1,195 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. 
+
+        Args:
+            shape (iterable[int]): The shape of the mask to be created. The
+                shape should have at least 3 dimensions. Samples are drawn
+                along the second last dimension.
+            seed (int, optional): Seed for the random number generator. Setting
+                the seed ensures the same mask is generated each time for the
+                same shape. The random state is reset afterwards.
+
+        Returns:
+            torch.Tensor: A mask of the specified shape.
+        """
+        if len(shape) < 3:
+            raise ValueError("Shape should have 3 or more dimensions")
+
+        with temp_seed(self.rng, seed):
+            num_cols = shape[-2]
+            center_fraction, acceleration = self.choose_acceleration()
+
+            # create the mask
+            num_low_freqs = int(round(num_cols * center_fraction))
+            prob = (num_cols / acceleration - num_low_freqs) / (
+                num_cols - num_low_freqs
+            )
+            mask = self.rng.uniform(size=num_cols) < prob
+            pad = (num_cols - num_low_freqs + 1) // 2
+            mask[pad : pad + num_low_freqs] = True
+
+            # reshape the mask
+            mask_shape = [1 for _ in shape]
+            mask_shape[-2] = num_cols
+            mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))
+
+            return mask
+
+
+class EquispacedMaskFunc(MaskFunc):
+    """
+    EquispacedMaskFunc creates a sub-sampling mask of a given shape.
+
+    The mask selects a subset of columns from the input k-space data. If the
+    k-space data has N columns, the mask picks out:
+        1. N_low_freqs = (N * center_fraction) columns in the center
+           corresponding to low-frequencies.
+        2. The other columns are selected with equal spacing at a proportion
+           that reaches the desired acceleration rate taking into consideration
+           the number of low frequencies. This ensures that the expected number
+           of columns selected is equal to (N / acceleration).
+
+    It is possible to use multiple center_fractions and accelerations, in which
+    case one possible (center_fraction, acceleration) is chosen uniformly at
+    random each time the EquispacedMaskFunc object is called.
+
+    Note that this function may not give equispaced samples (documented in
+    https://github.com/facebookresearch/fastMRI/issues/54), which will require
+    modifications to standard GRAPPA approaches. Nonetheless, this aspect of
+    the function has been preserved to match the public multicoil data.
+    """
+
+    def __call__(self, shape, seed):
+        """
+        Args:
+            shape (iterable[int]): The shape of the mask to be created. The
+                shape should have at least 3 dimensions. Samples are drawn
+                along the second last dimension.
+            seed (int, optional): Seed for the random number generator. Setting
+                the seed ensures the same mask is generated each time for the
+                same shape. The random state is reset afterwards.
+
+        Returns:
+            torch.Tensor: A mask of the specified shape.
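+
+        Illustrative usage (added; assumes a 320-column single-coil k-space
+        array, so the trailing dimension of ``shape`` is the complex dim):
+
+            >>> mask_func = EquispacedMaskFunc(center_fractions=[0.08], accelerations=[4])
+            >>> mask = mask_func((1, 320, 2), seed=42)
+            >>> mask.shape
+            torch.Size([1, 320, 1])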
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/transforms.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d9cecc761fc46e201705992ce6226598492f76af --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/dataloaders/transforms.py @@ -0,0 +1,493 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch + +from .math import ifft2c, fft2c, complex_abs +from .subsample import create_mask_for_mask_type, MaskFunc +import random + +from typing import Dict, Optional, Sequence, Tuple, Union +from matplotlib import pyplot as plt +import os + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data. + """ + data = data.numpy() + + return data[..., 0] + 1j * data[..., 1] + + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. 
+
+    Returns:
+        (tuple): tuple containing:
+            masked data (torch.Tensor): Subsampled k-space data
+            mask (torch.Tensor): The generated mask
+    """
+    shape = np.array(data.shape)
+    shape[:-3] = 1
+    mask = mask_func(shape, seed)
+    if padding is not None:
+        mask[:, :, : padding[0]] = 0
+        mask[:, :, padding[1] :] = 0  # padding value inclusive on right of zeros
+
+    masked_data = data * mask + 0.0  # the + 0.0 removes the sign of the zeros
+
+    return masked_data, mask
+
+
+def mask_center(x, mask_from, mask_to):
+    mask = torch.zeros_like(x)
+    mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to]
+
+    return mask
+
+
+def center_crop(data, shape):
+    """
+    Apply a center crop to the input real image or batch of real images.
+
+    Args:
+        data (torch.Tensor): The input tensor to be center cropped. It should
+            have at least 2 dimensions and the cropping is applied along the
+            last two dimensions.
+        shape (int, int): The output shape. The shape should be smaller than
+            the corresponding dimensions of data.
+
+    Returns:
+        torch.Tensor: The center cropped image.
+    """
+    assert 0 < shape[0] <= data.shape[-2]
+    assert 0 < shape[1] <= data.shape[-1]
+
+    w_from = (data.shape[-2] - shape[0]) // 2
+    h_from = (data.shape[-1] - shape[1]) // 2
+    w_to = w_from + shape[0]
+    h_to = h_from + shape[1]
+
+    return data[..., w_from:w_to, h_from:h_to]
+
+
+def complex_center_crop(data, shape):
+    """
+    Apply a center crop to the input image or batch of complex images.
+
+    Args:
+        data (torch.Tensor): The complex input tensor to be center cropped. It
+            should have at least 3 dimensions and the cropping is applied along
+            dimensions -3 and -2 and the last dimensions should have a size of
+            2.
+        shape (int, int): The output shape. The shape should be smaller than
+            the corresponding dimensions of data.
+
+    Returns:
+        torch.Tensor: The center cropped image.
+    """
+    assert 0 < shape[0] <= data.shape[-3]
+    assert 0 < shape[1] <= data.shape[-2]
+
+    w_from = (data.shape[-3] - shape[0]) // 2  #80
+    h_from = (data.shape[-2] - shape[1]) // 2  #80
+    w_to = w_from + shape[0]  #240
+    h_to = h_from + shape[1]  #240
+
+    return data[..., w_from:w_to, h_from:h_to, :]
+
+
+def center_crop_to_smallest(x, y):
+    """
+    Apply a center crop on the larger image to the size of the smaller.
+
+    The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at
+    dim=-1 and y is smaller than x at dim=-2, then the returned dimension will
+    be a mixture of the two.
+
+    Args:
+        x (torch.Tensor): The first image.
+        y (torch.Tensor): The second image.
+
+    Returns:
+        tuple: tuple of tensors x and y, each cropped to the minimum size.
+    """
+    smallest_width = min(x.shape[-1], y.shape[-1])
+    smallest_height = min(x.shape[-2], y.shape[-2])
+    x = center_crop(x, (smallest_height, smallest_width))
+    y = center_crop(y, (smallest_height, smallest_width))
+
+    return x, y
+
+
+def normalize(data, mean, stddev, eps=0.0):
+    """
+    Normalize the given tensor.
+
+    Applies the formula (data - mean) / (stddev + eps).
+
+    Args:
+        data (torch.Tensor): Input data to be normalized.
+        mean (float): Mean value.
+        stddev (float): Standard deviation.
+        eps (float, default=0.0): Added to stddev to prevent dividing by zero.
+
+    Returns:
+        torch.Tensor: Normalized tensor.
+    """
+    return (data - mean) / (stddev + eps)
+
+
+def normalize_instance(data, eps=0.0):
+    """
+    Normalize the given tensor with instance norm.
+
+    Applies the formula (data - mean) / (stddev + eps), where mean and stddev
+    are computed from the data itself.
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +class DataTransform(object): + """ + Data Transformer for training U-Net models. + """ + + def __init__(self, which_challenge): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.which_challenge = which_challenge + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. 
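+
+        Note (added for clarity): the low-resolution input is produced by the
+        body below, roughly
+
+            imgfft = fft2c(image)                             # back to k-space
+            imgfft = complex_center_crop(imgfft, (160, 160))  # keep the central 160x160 region
+            LR_image = complex_abs(ifft2c(imgfft))            # 160x160 low-resolution magnitude image
+
+        i.e. the input is degraded by discarding high-frequency k-space before
+        the inverse transform, then normalized with its own mean and std.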
+ """ + kspace = to_tensor(kspace) + + # inverse Fourier transform to get zero filled solution + image = ifft2c(kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + + # getLR + imgfft = fft2c(image) + imgfft = complex_center_crop(imgfft, (160, 160)) + LR_image = ifft2c(imgfft) + + # absolute value + LR_image = complex_abs(LR_image) + + # normalize input + LR_image, mean, std = normalize_instance(LR_image, eps=1e-11) + LR_image = LR_image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return LR_image, target, mean, std, fname, slice_num + +class DenoiseDataTransform(object): + def __init__(self, size, noise_rate): + super(DenoiseDataTransform, self).__init__() + self.size = (size, size) + self.noise_rate = noise_rate + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + max_value = attrs["max"] + + #target + target = to_tensor(target) + target = center_crop(target, self.size) + target, mean, std = normalize_instance(target, eps=1e-11) + target = target.clamp(-6, 6) + + #image + kspace = to_tensor(kspace) + complex_image = ifft2c(kspace) #complex_image + image = complex_center_crop(complex_image, self.size) + noise_image = self.rician_noise(image, max_value) + noise_image = complex_abs(noise_image) + + noise_image = normalize(noise_image, mean, std, eps=1e-11) + noise_image = noise_image.clamp(-6, 6) + + return noise_image, target, mean, std, fname, slice_num + + + def rician_noise(self, X, noise_std): + #Add rician noise with variance sampled uniformly from the range 0 and 0.1 + noise_std = random.uniform(0, noise_std*self.noise_rate) + Ir = X + noise_std * torch.randn(X.shape) + Ii = noise_std*torch.randn(X.shape) + In = torch.sqrt(Ir ** 2 + Ii ** 2) + return In + + +def apply_mask( + data: torch.Tensor, + mask_func: MaskFunc, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Optional[Sequence[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Subsample given k-space by multiplying with a mask. + Args: + data: The input k-space data. This should have at least 3 dimensions, + where dimensions -3 and -2 are the spatial dimensions, and the + final dimension has size 2 (for complex values). + mask_func: A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed: Seed for the random number generator. + padding: Padding value to apply for mask. + Returns: + tuple containing: + masked data: Subsampled k-space data + mask: The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + + +class ReconstructionTransform(object): + """ + Data Transformer for training U-Net models. 
+ """ + + def __init__(self, which_challenge, mask_func=None, use_seed=True): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.mask_func = mask_func + self.which_challenge = which_challenge + self.use_seed = use_seed + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. + """ + kspace = to_tensor(kspace) + + # apply mask + if self.mask_func: + seed = None if not self.use_seed else tuple(map(ord, fname)) + masked_kspace, mask = apply_mask(kspace, self.mask_func, seed) + # print("mask shape", mask.shape, mask.sum()) + # mask shape torch.Size([1, 368, 1]) tensor(89.) 
+ + else: + masked_kspace = kspace + # print("masked_kspace shape", masked_kspace.shape) + + # inverse Fourier transform to get zero filled solution + image = ifft2c(masked_kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + # print('image',image.shape) + # absolute value + image = complex_abs(image) + + # apply Root-Sum-of-Squares if multicoil data + if self.which_challenge == "multicoil": + image = rss(image) + + # normalize input + image, mean, std = normalize_instance(image, eps=1e-11) + image = image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return image, target, mean, std, fname, slice_num + + +def build_transforms(args, mode = 'train', use_kspace=False): + + challenge = 'singlecoil' + if use_kspace: + return ReconstructionTransform(challenge) + + else: + if mode == 'train': + mask = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + return ReconstructionTransform(challenge, mask, use_seed=False) + elif mode == 'val': + mask = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + return ReconstructionTransform(challenge, mask) + else: + return ReconstructionTransform(challenge) + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/debug/True_0_0.png b/MRI_recon/code/Frequency-Diffusion/FSMNet/debug/True_0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..31a3870d9dfabab22cda715602835720e5b02068 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/debug/True_0_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd3f386ad1d7bce293eb17089e4e0be9b7c33f3eadb46b303a0b445f9e6f07d4 +size 128798 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/documents/INSTALL.md b/MRI_recon/code/Frequency-Diffusion/FSMNet/documents/INSTALL.md new file mode 100644 index 0000000000000000000000000000000000000000..9912721cb3354240d99c08838ae8d2b1417b339b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/documents/INSTALL.md @@ -0,0 +1,11 @@ +## Dependency +The code is tested on `python 3.8, Pytorch 1.13`. 
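+
+After running the setup below, a quick check like the following (an added,
+illustrative snippet) confirms that the expected PyTorch build is active:
+
+```python
+import torch
+print(torch.__version__)          # expect 1.13.1+cu116
+print(torch.cuda.is_available())  # True if the cu116 wheel matches the driver
+```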
+ +##### Setup environment + +```bash +conda create -n FSMNet python=3.8 +source activate FSMNet # or conda activate FSMNet +pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install einops h5py matplotlib scikit_image tensorboardX yacs pandas opencv-python timm ml_collections +``` diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/__init__.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe1e0f26ca039d666189f901309dbb9adfbadc89 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/__init__.py @@ -0,0 +1,2 @@ +from .frequency_noise import add_frequency_noise +from .degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/__pycache__/__init__.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c9c9a8879fd669fe203e201da9baafda12ed989 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/__pycache__/frequency_noise.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/__pycache__/frequency_noise.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1af7bffc85e841cfeeaa1663fc8278285bfbca7c Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/__pycache__/frequency_noise.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__init__.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__pycache__/__init__.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c56c994142bc9cab23e625ab9439c87547390277 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__pycache__/k_degradation.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__pycache__/k_degradation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58c267426152d5c172358d3a30a101fc03376d02 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__pycache__/k_degradation.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__pycache__/mask_utils.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__pycache__/mask_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91d28c6fdb1cd81b53e31966a8618bdc2b2eadd9 Binary files /dev/null and 
b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/__pycache__/mask_utils.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/extract_example_mask.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/extract_example_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1955ea8bf7c2e7d678e80063002dfc6572e7b9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/extract_example_mask.py @@ -0,0 +1,71 @@ +import matplotlib.pyplot as plt +import torch +import numpy as np +from torch.fft import fft2, ifft2, fftshift, ifftshift + +# brats 4X +example = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/BraTS20_Training_036_90_t2_4X_undermri.png" +gt = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/BraTS20_Training_036_90_t2.png" +save_file = "./example_mask/brats_4X_mask.npy" + + +example = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/BraTS20_Training_036_90_t2_8X_undermri.png" +gt = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/BraTS20_Training_036_90_t2.png" +save_file = "./example_mask/brats_8X_mask.npy" + + + +example_img = plt.imread(example) # cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) +gt = plt.imread(gt) # cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) + +print("example_img shape: ", example_img.shape) +plt.imshow(example_img, cmap='gray') +plt.title("Example Frequency Image") +plt.show() + +example_img = torch.from_numpy(example_img).float() +fre = fftshift(fft2(example_img)) # ) +amp = torch.log(torch.abs(fre)) +plt.imshow(amp.squeeze(0).squeeze(0).numpy()) +plt.show() +angle = torch.angle(fre) +plt.imshow(angle.squeeze(0).squeeze(0).numpy()) +plt.show() + + +gt_fre = fftshift(fft2(torch.from_numpy(gt).float())) # ) +gt_amp = torch.log(torch.abs(gt_fre)) +plt.imshow(gt_amp.squeeze(0).squeeze(0).numpy()) +plt.show() +gt_angle = torch.angle(gt_fre) +plt.imshow(gt_angle.squeeze(0).squeeze(0).numpy()) +plt.show() + + +amp_mask = gt_amp.squeeze(0).squeeze(0).numpy() - amp.squeeze(0).squeeze(0).numpy() +amp_mask = np.mean(amp_mask, axis=0, keepdims=True) + +print("amp_mask shape: ", amp_mask) +thres = np.mean(amp_mask) +amp_mask[amp_mask < thres] = 1 +amp_mask[amp_mask >= thres] = 0 + + +#duplicate +amp_mask = np.repeat(amp_mask, 240, axis=0) + +plt.imshow(gt_amp.squeeze(0).squeeze(0).numpy() - amp.squeeze(0).squeeze(0).numpy()) +plt.show() +plt.imshow(gt_angle.squeeze(0).squeeze(0).numpy() - angle.squeeze(0).squeeze(0).numpy()) +plt.show() +plt.imshow(amp_mask) +plt.show() + +np.save(save_file, amp_mask) +# + + +load_backmask = np.load(save_file) +plt.imshow(load_backmask) +plt.show() + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/k_degradation.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/k_degradation.py new file mode 100644 index 0000000000000000000000000000000000000000..f200ce4b6a7634c91e2836e3562a1baef30de674 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/k_degradation.py @@ -0,0 +1,439 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift, fftn, ifftn + +try: + from frequency_diffusion.degradation.mask_utils import 
RandomMaskFunc, EquispacedMaskFunc +except: + from mask_utils import RandomMaskFunc, EquispacedMaskFunc + + +from torch import nn +import matplotlib.pyplot as plt + +def get_fade_kernel(dims, std): + fade_kernel = tgm.image.get_gaussian_kernel2d(dims, std) + fade_kernel = fade_kernel / torch.max(fade_kernel) + fade_kernel = torch.ones_like(fade_kernel) - fade_kernel + # if device_of_kernel == 'cuda': + # fade_kernel = fade_kernel.cuda() + fade_kernel = fade_kernel[1:, 1:] + return fade_kernel + + + +def get_fade_kernels(fade_routine, num_timesteps, image_size, kernel_std,initial_mask): + kernels = [] + for i in range(num_timesteps): + if fade_routine == 'Incremental': + kernels.append(get_fade_kernel((image_size + 1, image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + elif fade_routine == 'Constant': + kernels.append(get_fade_kernel( + (image_size + 1, image_size + 1), + (kernel_std, kernel_std))) + + elif fade_routine == 'Random_Incremental': + kernels.append(get_fade_kernel((2 * image_size + 1, 2 * image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + return torch.stack(kernels) + + +# --------------------------- +# Kspace kernels +# --------------------------- +# cartesian_regular +def get_mask_func(mask_method, af, cf): + if mask_method == 'cartesian_regular': + return EquispacedMaskFractionFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == 'cartesian_random': + return RandomMaskFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == "random": + return RandomMaskFunc([cf], [af]) + + elif mask_method == "randompatch": + return RandomPatchFunc([cf], [af]) + + elif mask_method == "equispaced": + return EquispacedMaskFunc([cf], [af]) + + else: + raise NotImplementedError + + +use_fix_center_ratio = False + +class Noisy_Patch(nn.Module): + def __init__(self): + super(Noisy_Patch, self).__init__() + self.af_list = [] + self.cf_list = [] + self.fe_list = [] + self.pe_list = [] + self.seed = 0 + + def append_list(self, at, cf, fe, pe): + self.af_list.append(at) + self.cf_list.append(cf) + self.fe_list.append(fe) + self.pe_list.append(pe) + + def get_noisy_patches(self, t): + af = self.af_list[t] + cf = self.cf_list[t] + fe = self.fe_list[t] + pe = self.pe_list[t] + + patch_mask = get_mask_func("randompatch", af, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=self.seed) # mask (numpy): (fe, pe) + return mask_ + + def forward(self, mask, ts): + # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + # print("use_patch_kernel forward:", t) + # print("mask = ", mask.shape) + # masks_ = [] + for id, t in enumerate(ts): + mask_ = self.get_noisy_patches(t)[0] + # print("mask_ = ", mask_.shape) + # print("mask[id, t] =", mask[t].shape) + + mask[t] = mask_.to(mask[t].device) * mask[t] + self.seed += ts[0].item() + + # masks_ = torch.stack(masks_).cuda() + # print("masks_ = ", masks_.shape) + # print("mask = ", mask.shape) # B, T, H, W + + return mask + +get_noisy_patches = Noisy_Patch() + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False): + mask_func = get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random', 'equispaced']: + # print("pe") + mask = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + 
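+        # Added note: the patch-drop stage registered below uses a milder
+        # acceleration than the Cartesian line mask itself; af_new keeps only
+        # half of the extra undersampling (e.g. af=4 -> af_new=2.5) before it
+        # is handed to get_noisy_patches.append_list for later patch-wise drops.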
af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'equispaced': + mask = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (pe, pe) --> (1, pe, pe) + + else: + raise NotImplementedError + + # print("return mask = ", mask.shape) + return mask + + + +def get_ksu_kernel(timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=4, + center_fraction=0.08, + accelerate_mask=None): + + if accelerated_factor == 4: + mask_method, center_fraction = "cartesian_random", center_fraction #0.08 # 0.15 + + elif accelerated_factor == 8: + mask_method, center_fraction = "equispaced", center_fraction # 0.04 + + elif accelerated_factor == 12: + mask_method, center_fraction = "equispaced", center_fraction + + + center_ratio_factor = center_fraction * accelerated_factor + + masks = [] + noisy_masks = [] + ksu_mask_pe = ksu_mask_fe = image_size # , ksu_mask_pe=320, ksu_mask_fe=320 + # ksu_mask_fe + if ksu_routine == 'LinearSamplingRate': + # Generate the sampling rate list with torch.linspace, reversed, and skip the first element + sr_list = torch.linspace(start=1/accelerated_factor, end=1, steps=timesteps + 1).flip(0) + sr_list = [sr.item() for sr in sr_list] + # Start from 0.01 + for sr in sr_list: + sr = sr.item() + af = 1 / sr # * accelerated_factor # acceleration factor + cf = center_fraction if use_fix_center_ratio else sr_list[0] * center_ratio_factor + + masks.append(get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe)) + + elif ksu_routine == 'LogSamplingRate': + + # Generate the sampling rate list with torch.logspace, reversed, and skip the first element + sr_list = torch.logspace(start=-torch.log10(torch.tensor(accelerated_factor)), + end=0, steps=timesteps + 1).flip(0) + + sr_list = [sr.item() for sr in sr_list] + af = 1 / sr_list[-1] + cf = center_fraction if use_fix_center_ratio else sr_list[-1] * center_ratio_factor + + + if isinstance(accelerate_mask, type(None)): + cache_mask = get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe) + print("cache_mask = ", cache_mask.shape) # torch.Size([1, 320, 320]) + else: + cache_mask = accelerate_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + masks.append(cache_mask) + + sr_list = sr_list[:-1][::-1] #.flip(0) # Flip? 
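+
+        # Added note: the remaining sampling rates are traversed in ascending
+        # order below; each pass switches on the unsampled phase-encoding lines
+        # closest to the k-space centre until int(H / af) lines are active, so
+        # the mask sequence interpolates from the target acceleration back
+        # towards full sampling. After the loop the list is reversed, making
+        # masks[0] the densest mask and the last entry the most undersampled.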
+ + for sr in sr_list: + af = 1 / sr + cf = center_fraction if use_fix_center_ratio else sr * center_ratio_factor + # print("af = ", af, cf) + + H, W = cache_mask.shape[1], cache_mask.shape[2] + new_mask = cache_mask.clone() + + # Add additional lines to the mask based on new acceleration factor + total_lines = H + sampled_lines = int(total_lines / af) + existing_lines = new_mask.squeeze(0).sum(dim=0).nonzero(as_tuple=True)[0].tolist() + + remaining_lines = [i for i in range(total_lines) if i not in existing_lines] + + if sampled_lines > len(existing_lines): + center = W // 2 + additional_lines = sampled_lines - len(existing_lines) # sample number + + sorted_indices = sorted(remaining_lines, key=lambda x: abs(x - center)) + + # Take the closest `additional_lines` indices + sampled_indices = sorted_indices[:additional_lines] + + # Remove sampled indices from remaining_lines + for idx in sampled_indices: + remaining_lines.remove(idx) + + # Update new_mask for each sampled index + for idx in sampled_indices: + new_mask[:, :, idx] = 1.0 + + + + cache_mask = new_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + + masks.append(cache_mask) + + # reverse + masks = masks[::-1] + noisy_masks = masks # noisy_masks[::-1] + + + elif mask_method == 'gaussian_2d': + raise NotImplementedError("Gaussian 2D mask type is not implemented.") + + else: + raise NotImplementedError(f'Unknown k-space undersampling routine {ksu_routine}') + + # Return masks, excluding the first one + return masks[1:] + + + +class high_fre_mask: + def __init__(self): + self.mask_cache = {} + + def __call__(self, H, W): + if (H, W) in self.mask_cache: + return self.mask_cache[(H, W)] + center_x, center_y = H // 2, W // 2 + radius = H//8 # 影响的频率范围半径 + + high_freq_mask = torch.ones(H, W) + for i in range(H): + for j in range(W): + if (i - center_x) ** 2 + (j - center_y) ** 2 <= radius ** 2: + high_freq_mask[i, j] = 0.0 + self.mask_cache[(H, W)] = high_freq_mask + return high_freq_mask + + +high_fre_mask_cls = high_fre_mask() + + + +def apply_ksu_kernel(x_start, mask): + fft, mask = apply_tofre(x_start, mask) + fft = fft * mask + x_ksu = apply_to_spatial(fft) + + return x_ksu + +# from dataloaders.math import ifft2c, fft2c, complex_abs + +def apply_tofre(x_start, mask): + # B, C, H, W = x_start.shape + kspace = fftshift(fft2(x_start, norm=None, dim=(-2, -1)), dim=(-2, -1)) # Default: all dimensions + mask = mask.to(kspace.device) + return kspace, mask + +def apply_to_spatial(fft): + x_ksu = ifft2(ifftshift(fft, dim=(-2, -1)), norm=None, dim=(-2, -1)) # ortho + # After ifftn, the output is already in the spatial domain + x_ksu = x_ksu.real #torch.abs(x_ksu) # + return x_ksu + + + +if __name__ == "__main__": + # First STEP + import matplotlib.pyplot as plt + import numpy as np + + image_size = 64 + + masks = get_ksu_kernel(25, image_size, + "LinearSamplingRate", is_training=True) # LogSamplingRate + + + batch_size = 1 + + img = plt.imread("/Users/haochen/Documents/GitHub/DiffDecomp/Cold-Diffusion/defading-diffusion-pytorch/diffusion_pytorch/assets/img.png") + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + # to gray scale + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + # img = np.transpose(img, (2, 0, 1)) + # img = img[0] + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + print(" img shape: ", img.shape, img.max(), img.min()) + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, 
(batch_size,)).long() + print("rand_x shape:", rand_x.shape, rand_x) + + img = img * 2 - 1 # + + masked_img = [] + + for m in masks: + m = m.unsqueeze(0) + img = apply_ksu_kernel(img, m, pixel_range='-1_1', ) + masked_img.append(img) + + masks = np.concatenate(masks, axis=-1)[0] + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + print(" masked_img shape: ", masked_img.shape) + print(" mask shape: ", masks.shape) + + img = np.concatenate([masks, masked_img], axis=0) + + plt.figure(figsize=(100, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + + print("\n\nSecond stage...") + + + # Second STEP + import matplotlib.pyplot as plt + import numpy as np + + image_size = 64 + batch_size = 1 + t = 25 + kspace_kernels = get_ksu_kernel(t, image_size, ksu_routine="LogSamplingRate", is_training=True) # 2 * + kspace_kernels = torch.stack(kspace_kernels).squeeze(1) + + img = plt.imread( + "/Users/haochen/Documents/GitHub/DiffDecomp/Cold-Diffusion/generation-diffusion-pytorch/diffusion_pytorch/assets/img.png") + img = cv2.resize(img, (image_size, image_size)) + + img = np.transpose(img, (2, 0, 1)) + img = img[0] + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + print(" img shape: ", img.shape, img.max(), img.min()) + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + print("rand_x shape:", rand_x.shape, rand_x) + + for i in range(batch_size): + print("kspace_kernels[j] shape = ", kspace_kernels[i].shape, rand_x[i]) + # k = kspace_kernels[j].clone() + rand_kernels.append(torch.stack( + [kspace_kernels[j][: image_size, # rand_x[i]:rand_x[i] + + : image_size] for j in + range(len(kspace_kernels))])) + + rand_kernels = torch.stack(rand_kernels) + + # rand_kernels shape: torch.Size([24, 5, 128, 128]) + print("=== rand_kernels: ", rand_kernels.shape, kspace_kernels[0].shape) + + masked_img = [] + masks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + masks.append(k) + # print("-- k shape: ", k.shape) + # print("-- img shape: ", img.shape) + + img = apply_ksu_kernel(img, k, pixel_range='0_1') + masked_img.append(img) + + masks = np.concatenate(masks, axis=-1)[0] + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + # print(" masked_img shape: ", masked_img.shape) + # print(" mask shape: ", masks.shape) + + img = np.concatenate([masks, masked_img], axis=0) + + plt.figure(figsize=(100, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/kspace_test.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/kspace_test.py new file mode 100644 index 0000000000000000000000000000000000000000..aa490fe2ac1a25366b0750bd5fe3d4b785c414ff --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/kspace_test.py @@ -0,0 +1,274 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift +import matplotlib.pyplot as plt +from mask_utils import RandomMaskFunc, EquispacedMaskFunc + + +try: + from mask_utils import 
RandomMaskFunc, EquispacedMaskFunc +except: + from mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquispacedMaskFunc, RandomPatchFunc + +try: + from .k_degradation import get_fade_kernel, get_fade_kernels, get_mask_func, high_fre_mask, get_ksu_kernel +except: + from k_degradation import get_fade_kernel, get_fade_kernels, get_mask_func, high_fre_mask, get_ksu_kernel + + +use_fix_center_ratio = False + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False): + mask_func = get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random']: + + mask, num_low_freq = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + if is_training: # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + af_new = 1.0 + (af - 1.0) / 2 + # af_new = max(af_new, 1.0) + + patch_mask = get_mask_func("randompatch", af_new, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=seed) # mask (numpy): (fe, pe) + + mask = mask_ * mask + + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (pe, pe) --> (1, pe, pe) + + else: + raise NotImplementedError + + # print("return mask = ", mask.shape) + return mask + + +# ksu_masks = get_ksu_kernels() +# (C, H, W) --> (B, C, H, W) + + +high_fre_mask_cls = high_fre_mask() + + +def apply_ksu_kernel(x_start, mask, params_dict=None, pixel_range='mean_std', + use_fre_noise=False, return_mask=False): + fft, mask = apply_tofre(x_start, mask, params_dict, pixel_range) + + # Use the high frequency mask to add noise + if use_fre_noise: + fft = fft * mask + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + _, _, H, W = fft.shape + + high_freq_mask = high_fre_mask_cls(H, W).to(fft.device) + high_freq_mask = high_freq_mask.unsqueeze(0).unsqueeze(0).repeat(fft.shape[0], 1, 1, 1) + + # Background Noise + sigma = 0.2 + noise = torch.randn_like(fft_magnitude) * sigma + mean_mag = fft_magnitude.sum() / (mask.sum() + 1) + + noise_magnitude_high = noise * (mean_mag) * (1 - mask) # high_freq_mask + + sigma = 0.1 + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude_low = noise * fft_magnitude * mask # (1 - high_freq_mask) + + # fft_noisy_magnitude = fft_magnitude * mask + noise_magnitude * high_freq_mask * (1 - mask) + fft_noisy_magnitude = fft_magnitude * mask + fft_noisy_magnitude += noise_magnitude_high + noise_magnitude_low + fft_noisy_magnitude = torch.clamp(fft_noisy_magnitude, min=0.0) + + fft = fft_noisy_magnitude * torch.exp(1j * fft_phase) + + else: + fft = fft * mask + + x_ksu = apply_to_spatial(fft, params_dict, pixel_range) + if return_mask: + return x_ksu, fft, fft_magnitude + + return x_ksu + + +def apply_tofre(x_start, mask, params_dict=None, pixel_range='mean_std'): + fft = fftshift(fft2(x_start)) + mask = mask.to(fft.device) + return fft, mask # , _min, _max + + +def apply_to_spatial(fft, params_dict=None, pixel_range='mean_std'): + x_ksu = ifft2(ifftshift(fft)) + 
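+    # Added note: ifftshift undoes the centring applied in apply_tofre before
+    # the inverse FFT; the magnitude taken on the next line maps the complex
+    # result back to a real-valued image.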
x_ksu = torch.abs(x_ksu) + + return x_ksu + + +if __name__ == "__main__": + # First STEP + import SimpleITK as sitk + + import numpy as np + import os + + image_size = 240 + batch_size = 1 + t = 5 + + + + + use_linux = True + + # Load MRI back here + if use_linux: + root = "/gamedrive/Datasets/medical/Brain/brats/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData" + p_id = 639 + modality = "T1C" + filename = f"{root}/BraTS-GLI-{p_id:05d}-000/BraTS-GLI-{p_id:05d}-000-{modality.lower()}.nii.gz" + img_obj = sitk.ReadImage(filename) + img_array = sitk.GetArrayFromImage(img_obj) + + slice = img_array.shape[0] // 2 + img = img_array[slice, ...] + plt.imshow(img, cmap="gray") + plt.show() + img = (img - img.min()) / (img.max() - img.min()) + + plt.imsave("visualization/original.png", img, cmap="gray") + + else: + # Or use PNG + img = plt.imread( + "/Users/haochen/Documents/GitHub/DiffDecomp/Cold-Diffusion/generation-diffusion-pytorch/diffusion_pytorch/assets/img.png") + img = np.transpose(img, (2, 0, 1))[0] + + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + print("img shape=", img.shape) + + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + + ksu_routine = "LogSamplingRate" # "LinearSamplingRate" # + kspace_kernels, patch_drop_masks = get_ksu_kernel(t, image_size, + ksu_routine=ksu_routine, is_training=True, + example_frequency_img=example) + kspace_kernels = torch.stack(kspace_kernels).squeeze(1) + + # all k_space + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + for i in range(batch_size): + # k = kspace_kernels[j].clone() + rand_kernels.append(torch.stack( + [kspace_kernels[j][: image_size, # rand_x[i]:rand_x[i] + + : image_size] for j in + range(len(kspace_kernels))])) + + rand_kernels = torch.stack(rand_kernels) + + # rand_kernels shape: torch.Size([24, 5, 128, 128]) + ori_img = img + + masked_img = [] + masks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + masks.append(k) + + img = apply_ksu_kernel(img, k, pixel_range='0_1') + + masked_img.append(img) + + # Visualize the masks + masks = np.concatenate(masks, axis=-1)[0] + + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # Save individually + + print("masks / masked_img=", masks.max(), masked_img.max()) + # img = np.concatenate([masks, masked_img], axis=0) + + plt.imsave("visualization/sample_masks.png", masks, cmap='gray') + + # masked_img = (masked_img - masked_img.min())/(masked_img) + # masked_img = np.concatenate([masked_img, 1-masked_img], axis=0) + plt.imsave("visualization/sample_images.png", masked_img, cmap='gray') + + w = masked_img.shape[0] + pr_folder = "visualization/progressive" + os.makedirs(pr_folder, exist_ok=True) + + # Progressive + print() + for i in range(t): + plt.imsave(f"{pr_folder}/{i}_masked_img.png", masked_img[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_masks.png", masks[:, i * w: (i + 1) * w], cmap='gray') + + img = ori_img + # use_fre_noise=False, return_mask=False + masked_img = [] + masks = [] + fft = [] + ks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + ks.append(k) + + img, k, fft_original = apply_ksu_kernel(img, k, pixel_range='0_1', use_fre_noise=True, return_mask=True) + + # k -> fft + fft_magnitude = np.abs(k) # 幅度 + # fft_phase = torch.angle(k) # 相位 + + mag = np.log(fft_magnitude[0]) + masks.append(mag) + 
fft.append(np.log(fft_original[0])) + + masked_img.append(img) + + # Visualize the masks + masks = np.concatenate(masks, axis=-1)[0] + ks = np.concatenate(ks, axis=-1)[0] + + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + + fft = np.concatenate(fft, axis=-1)[0] + + plt.imsave("visualization/sample_noisy_mask.png", masks, cmap='gray') + + # masked_img = np.concatenate([masked_img, 1 - masked_img], axis=0) + plt.imsave("visualization/sample_noisy_image.png", masked_img, cmap='gray') + # print("masked_img shape=", masked_img.shape, w) + + # Progressive + for i in range(t): + # print("masked_img[:, t*w: (t+1)*w] = ", masked_img[:, t*w: (t+1)*w].shape, t*w) + + plt.imsave(f"{pr_folder}/{i}_n_masked_img.png", masked_img[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_masks.png", masks[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_fft.png", fft[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_ks.png", ks[:, i * w: (i + 1) * w], cmap='gray') + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/mask_utils.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6eda8a0397fb628cabc4e1d97f93ae9db37377f3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/degradation/mask_utils.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. 
+ """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # print("center_fraction = ", center_fraction) + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. 
This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/frequency_noise.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/frequency_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..275fddb66f47bc02036fca5d31ec121b55939baa --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/frequency_diffusion/frequency_noise.py @@ -0,0 +1,39 @@ +import torch + +def add_frequency_noise(fft, snr=10, vacant_snr=15, mask=None): + ### 根据SNR确定noise的放大比例 + num_pixels = fft.numel() + + fft_magnitude = torch.abs(fft) + fft_phase = torch.angle(fft) + + # fft_magnitude + mag_psr = torch.mean(torch.abs(fft_magnitude) ** 2) + mag_pnr = mag_psr / (10 ** (snr / 10)) # Calculate noise power + noise_mag = torch.randn_like(fft_magnitude) * torch.sqrt(mag_pnr) + + mag_psr_vacant = mag_psr / (10 ** (vacant_snr / 10)) + noise_mag_vacant = torch.randn_like(fft_magnitude) * torch.sqrt(mag_psr_vacant) + + fft_magnitude = fft_magnitude + \ + noise_mag * fft_magnitude * mask + \ + noise_mag_vacant * (1- mask) + fft_magnitude = torch.abs(fft_magnitude) + + # fft_phase + pha_psr = torch.mean(torch.abs(fft_phase) ** 2) + pha_pnr = pha_psr / (10 ** (snr / 10)) # Calculate noise power for phase + noise_pha = torch.randn_like(fft_phase) * torch.sqrt(pha_pnr) + + pha_psr_vacant = pha_psr / (10 ** (vacant_snr / 10)) + noise_pha_vacant = torch.randn_like(fft_phase) * torch.sqrt(pha_psr_vacant) + + fft_phase = fft_phase + \ + noise_pha * 
fft_phase * mask + \ + noise_pha_vacant * (1- mask) + + noise_fft = fft_magnitude * torch.exp(1j * fft_phase) + + return noise_fft + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/knee_data_split/singlecoil_train_split_less.csv b/MRI_recon/code/Frequency-Diffusion/FSMNet/knee_data_split/singlecoil_train_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..d85707318750900b14a6e7100541242a60b7a310 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/knee_data_split/singlecoil_train_split_less.csv @@ -0,0 +1,227 @@ +file1000685,file1000568,0.301723929779229 +file1002273,file1000481,0.302226224199571 +file1000472,file1000142,0.304272730770318 +file1002186,file1000863,0.304812175768496 +file1002385,file1002518,0.305357274240413 +file1000981,file1000129,0.305533361411383 +file1001320,file1001948,0.306821514316368 +file1000633,file1002243,0.306892354331709 +file1001872,file1001294,0.308345907393103 +file1001474,file1001830,0.310481695157561 +file1001005,file1001283,0.310497722435023 +file1001690,file1001519,0.310709448786299 +file1002469,file1001811,0.31193137253455 +file1000914,file1000242,0.31237190359308 +file1002284,file1002012,0.315366393843169 +file1001721,file1001328,0.31735122361847 +file1000807,file1002334,0.320096908959039 +file1001944,file1002335,0.320272061156991 +file1002090,file1002431,0.320351887633851 +file1000499,file1002063,0.320786426659383 +file1001362,file1000509,0.32175341740359 +file1001421,file1000597,0.324291432700032 +file1000349,file1000321,0.324545110048573 +file1002123,file1001235,0.327142348994532 +file1001867,file1002086,0.328624781732941 +file1001007,file1001027,0.330759860300298 +file1001915,file1000088,0.331499371283099 +file1001661,file1000313,0.331905252950291 +file1000383,file1000307,0.339998107225229 +file1000116,file1000632,0.34069458535013 +file1002303,file1000173,0.343821267871409 +file1000306,file1001277,0.344751178043605 +file1000003,file1001922,0.346138116633394 +file1000109,file1000143,0.347632265547478 +file1001999,file1000115,0.348248659775587 +file1000089,file1000326,0.348964657514049 +file1001205,file1002232,0.349375610862454 +file1000557,file1000619,0.351305005151048 +file1001823,file1000778,0.352076809462453 +file1000806,file1001130,0.352659078122633 +file1000365,file1000351,0.352772816610486 +file1002374,file1001778,0.352974481603711 +file1002516,file1001910,0.359896103026675 +file1001200,file1000931,0.360070003966827 +file1001479,file1000952,0.360424533696936 +file1000850,file1001942,0.362632797518558 +file1001426,file1002143,0.363271909822866 +file1001304,file1001333,0.36404737582222 +file1000390,file1000518,0.364744579516818 +file1000830,file1002096,0.365897427529429 +file1000794,file1001856,0.365973692948894 +file1001266,file1001327,0.366395851089761 +file1001692,file1002352,0.36655953875445 +file1001564,file1001024,0.367284385415205 +file1001861,file1002050,0.36783497787384 +file1002066,file1002361,0.367964419694875 +file1001613,file1002087,0.368231014746024 +file1001931,file1000220,0.368847112914793 +file1000339,file1000554,0.370123905662701 +file1000754,file1002208,0.37031588493778 +file1001067,file1001956,0.371313060558732 +file1000101,file1001053,0.372141932838775 +file1002520,file1002409,0.372501194473693 +file1001459,file1001615,0.373295536945146 +file1001673,file1000508,0.376416667681519 +file1002201,file1001228,0.376680033570078 +file1000058,file1002449,0.376927627737029 +file1001748,file1001042,0.378067114701689 +file1001941,file1000376,0.37841176147662 
+file1000801,file1002545,0.378423759459738 +file1000010,file1000535,0.38111194591455 +file1000882,file1002154,0.382223600234592 +file1001694,file1001297,0.382545161354354 +file1001992,file1002456,0.382664563820782 +file1001666,file1001773,0.382892588770697 +file1001629,file1002514,0.383417073960824 +file1002113,file1000738,0.385439884728523 +file1002221,file1000569,0.385903801966773 +file1002296,file1002117,0.387319754665673 +file1000693,file1001945,0.387855926202209 +file1001410,file1000223,0.391284037867147 +file1002071,file1001425,0.391497653794399 +file1002325,file1001259,0.391913965917762 +file1002430,file1001969,0.392256443856501 +file1002462,file1000708,0.393161981208355 +file1002358,file1001888,0.39427809496515 +file1000485,file1000753,0.395316199436001 +file1002357,file1001973,0.39564210237905 +file1002130,file1002041,0.395978941103639 +file1002569,file1000097,0.397496127623486 +file1002264,file1000148,0.397630184088734 +file1002381,file1001401,0.398105992102355 +file1000289,file1000585,0.399527637723015 +file1002368,file1001723,0.400243022234875 +file1002342,file1001319,0.400431803928825 +file1002170,file1001226,0.400632448147846 +file1001385,file1001758,0.400855988878681 +file1001732,file1002541,0.40091828863264 +file1001102,file1000762,0.400923140595936 +file1001470,file1000181,0.401353492516182 +file1000400,file1000884,0.401562860630016 +file1002293,file1002523,0.401800994807451 +file1000728,file1001654,0.402763341041675 +file1000582,file1001491,0.403451830806034 +file1000586,file1001521,0.403648293267187 +file1002287,file1001770,0.405194821414496 +file1000371,file1000159,0.405999000381268 +file1002356,file1002064,0.406519210876811 +file1000324,file1000590,0.407593694425997 +file1001622,file1001710,0.40759525378577 +file1002037,file1000403,0.407814136488744 +file1002444,file1000743,0.40943197761463 +file1001175,file1002088,0.410423663035312 +file1001391,file1000540,0.410854355646853 +file1002133,file1001186,0.411248429534111 +file1001229,file1001630,0.411355571792039 +file1002283,file1000402,0.411836769927671 +file1000627,file1000161,0.412089060388579 +file1001701,file1001402,0.412854774524637 +file1000795,file1000452,0.413448916432685 +file1000354,file1000947,0.41459642292987 +file1002043,file1002505,0.414863932355455 +file1001285,file1001113,0.418183757940871 +file1000170,file1001832,0.419441549204313 +file1002399,file1001500,0.419905873946513 +file1002439,file1000177,0.42054051043224 +file1001656,file1001217,0.420597020703942 +file1000296,file1000065,0.420845042251081 +file1000626,file1001623,0.42087934790355 +file1001767,file1000760,0.422315537515139 +file1000467,file1001246,0.422371268999111 +file1001033,file1000611,0.42425275873442 +file1002304,file1000221,0.425602179771197 +file1001737,file1001141,0.425716789218234 +file1001565,file1000559,0.426158561043574 +file1000249,file1000643,0.426541100077021 +file1002014,file1001109,0.426587840438723 +file1002006,file1000790,0.427829459781438 +file1000193,file1000750,0.428103808477214 +file1001993,file1001110,0.428186367615143 +file1002094,file1001814,0.428868578868176 +file1000098,file1001420,0.428968675677784 +file1000336,file1000211,0.430347427208789 +file1001498,file1002568,0.43204475404071 +file1001671,file1001106,0.432215802861284 +file1000426,file1002386,0.43283446816702 +file1001520,file1002481,0.434867670495723 +file1002189,file1001432,0.434924370194975 +file1001390,file1002554,0.435313848731387 +file1002166,file1001982,0.435387512979012 +file1001120,file1001006,0.435594761785839 
+file1000149,file1001985,0.436289528591294 +file1001632,file1001008,0.436682374331417 +file1002567,file1001155,0.437221000601772 +file1000434,file1002195,0.438098100114814 +file1002532,file1001048,0.438500899539101 +file1001605,file1000927,0.438686659342641 +file1000479,file1000120,0.439587267995034 +file1002473,file1001388,0.439594997597548 +file1001108,file1002228,0.440528754793898 +file1002099,file1002056,0.440776843467602 +file1000191,file1002127,0.441114509542672 +file1000875,file1002494,0.441378135507993 +file1002161,file1000002,0.441912476744187 +file1002269,file1001220,0.442742296865228 +file1001295,file1001355,0.4435162405589 +file1001659,file1001023,0.444686151316673 +file1001857,file1001378,0.447500830900898 +file1001183,file1001370,0.447782748040587 +file1000428,file1000859,0.448328910257083 +file1000588,file1002227,0.448650488897259 +file1001098,file1000486,0.448862467740607 +file1001288,file1000408,0.450363676957042 +file1002097,file1001210,0.451126832474666 +file1000216,file1001082,0.451550143520946 +file1001746,file1001642,0.451781042569196 +file1002388,file1000204,0.451940333555972 +file1000021,file1000560,0.452234621797968 +file1000489,file1001545,0.452796032302523 +file1001116,file1000883,0.453096911915119 +file1001372,file1000561,0.45532542913335 +file1001276,file1000424,0.45534174289324 +file1000974,file1002098,0.455371894001872 +file1002566,file1002044,0.455937677517583 +file1000262,file1002046,0.456056330767294 +file1001619,file1001342,0.456559091350965 +file1000045,file1001616,0.457599407743834 +file1001468,file1002115,0.458095965024278 +file1001061,file1000233,0.460561351667266 +file1000558,file1000100,0.461094222462111 +file1000605,file1000691,0.461429521647285 +file1000640,file1000384,0.463383466503099 +file1000410,file1001358,0.463452482427773 +file1000851,file1001014,0.463558384057952 +file1001092,file1000138,0.463591264436099 +file1000061,file1002049,0.465778207162619 +file1001206,file1000983,0.466701211830884 +file1000256,file1000475,0.466865377968187 +file1002434,file1001387,0.467154181996099 +file1001036,file1000210,0.470404279499276 +file1001540,file1001860,0.472822271037545 +file1001244,file1001154,0.475076170733515 +file1000131,file1001526,0.475459563440874 +file1000180,file1002045,0.476814451110009 +file1001837,file1000637,0.478851985878026 +file1002425,file1001891,0.481451070031007 +file1001056,file1000682,0.482320170742015 +file1002276,file1000777,0.483452141843029 +file1001139,file1002544,0.487462418948035 +file1000548,file1001257,0.488098081542811 +file1000188,file1001286,0.488423105111001 +file1001879,file1000999,0.488449105381724 +file1001062,file1000231,0.48930683373911 +file1000040,file1001873,0.492070802214623 +file1002286,file1000066,0.493213986773381 +file1002474,file1002563,0.501584439120211 +file1000967,file1000563,0.502066261411662 +file1001307,file1002048,0.50460435259807 +file1000483,file1001699,0.511819026566198 +file1001528,file1000285,0.512629017841038 +file1001742,file1002371,0.513805213204644 +file1002397,file1000592,0.515406473057 +file1000069,file1000510,0.528220553613126 +file1001087,file1001300,0.536510449049583 +file1001991,file1000836,0.538145797125916 +file1001382,file1001806,0.538539506621535 +file1000111,file1001189,0.557690760784602 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/knee_data_split/singlecoil_val_split_less.csv b/MRI_recon/code/Frequency-Diffusion/FSMNet/knee_data_split/singlecoil_val_split_less.csv new file mode 100644 index 
0000000000000000000000000000000000000000..a1cbac5537562063359f4ac3e0985de51cb989b2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/knee_data_split/singlecoil_val_split_less.csv @@ -0,0 +1,45 @@ +file1000323,file1002538,0.30754967523156 +file1001458,file1001566,0.310512744537048 +file1000885,file1001059,0.318226346221521 +file1000464,file1000196,0.321465466968232 +file1000314,file1000178,0.327505552363568 +file1001163,file1001289,0.328954963947692 +file1000033,file1001191,0.330925609207301 +file1000976,file1000990,0.344036229323198 +file1001930,file1001834,0.345994076497818 +file1002546,file1001344,0.351762252794677 +file1000277,file1001429,0.353297786572139 +file1001893,file1001262,0.358064285890878 +file1000926,file1002067,0.360639004205491 +file1001650,file1002002,0.362186928073579 +file1001184,file1001655,0.362592305723707 +file1001497,file1001338,0.365599407221502 +file1001202,file1001365,0.3844323497275 +file1001126,file1002340,0.388929627976346 +file1001339,file1000291,0.391300537691403 +file1002187,file1001862,0.39883786878841 +file1000041,file1000591,0.39896683485823 +file1001064,file1001850,0.399687813966601 +file1001331,file1002214,0.400340820924839 +file1000831,file1000528,0.403582747590964 +file1000769,file1000538,0.405298051020298 +file1000182,file1001968,0.407646172205036 +file1002382,file1001651,0.410749052045234 +file1000660,file1000476,0.415423894745454 +file1002570,file1001726,0.424622351472032 +file1001585,file1000858,0.426738511964108 +file1000190,file1000593,0.428080574167047 +file1001170,file1001090,0.429987089825525 +file1002252,file1001440,0.432038842370013 +file1000697,file1001144,0.432558506761396 +file1001077,file1000000,0.441922503777368 +file1001381,file1001119,0.455418270809002 +file1001759,file1001851,0.460824505737749 +file1000635,file1002389,0.465674267492171 +file1001668,file1001689,0.467330511330772 +file1001221,file1000818,0.469630000354232 +file1001298,file1002145,0.473526387887779 +file1001763,file1001938,0.47398893150184 +file1001444,file1000942,0.48507438696692 +file1000735,file1002007,0.496530240691134 +file1000477,file1000280,0.528508000547834 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..23d6a82907fb397cfecb49be3661d690847a722e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44abd53cb5ef9038368bc4ff013a461f51f3fbf0afc765db91b8db06126b9af8 +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..232b191491ee4d18e6e06b9335c3bf05add1c248 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log.txt @@ -0,0 +1,1105 @@ +[23:34:38.468] Namespace(root_path='/home/v-qichen3/MRI_recon/data/fastmri', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='0', exp='FSMNet_fastmri_4x', max_iterations=100000, batch_size=4, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, 
snapshot_path='None', rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[23:36:13.017] iteration 100 [91.88 sec]: learning rate : 0.000100 loss : 1.107095 +[23:37:44.352] iteration 200 [183.21 sec]: learning rate : 0.000100 loss : 0.649038 +[23:39:15.312] iteration 300 [274.17 sec]: learning rate : 0.000100 loss : 0.726796 +[23:40:46.254] iteration 400 [365.12 sec]: learning rate : 0.000100 loss : 0.948658 +[23:42:17.161] iteration 500 [456.02 sec]: learning rate : 0.000100 loss : 0.498335 +[23:43:48.128] iteration 600 [546.99 sec]: learning rate : 0.000100 loss : 0.586920 +[23:45:19.133] iteration 700 [638.00 sec]: learning rate : 0.000100 loss : 1.196417 +[23:46:50.024] iteration 800 [728.89 sec]: learning rate : 0.000100 loss : 0.725201 +[23:48:20.968] iteration 900 [819.83 sec]: learning rate : 0.000100 loss : 0.678973 +[23:49:51.921] iteration 1000 [910.78 sec]: learning rate : 0.000100 loss : 0.674922 +[23:51:22.810] iteration 1100 [1001.67 sec]: learning rate : 0.000100 loss : 0.723473 +[23:52:53.753] iteration 1200 [1092.62 sec]: learning rate : 0.000100 loss : 0.898708 +[23:54:24.721] iteration 1300 [1183.58 sec]: learning rate : 0.000100 loss : 0.412233 +[23:55:55.624] iteration 1400 [1274.49 sec]: learning rate : 0.000100 loss : 0.507603 +[23:57:26.584] iteration 1500 [1365.45 sec]: learning rate : 0.000100 loss : 0.602146 +[23:58:57.522] iteration 1600 [1456.38 sec]: learning rate : 0.000100 loss : 0.640982 +[00:00:28.431] iteration 1700 [1547.29 sec]: learning rate : 0.000100 loss : 1.066198 +[00:01:59.430] iteration 1800 [1638.29 sec]: learning rate : 0.000100 loss : 0.464903 +[00:03:30.338] iteration 1900 [1729.20 sec]: learning rate : 0.000100 loss : 0.873845 +[00:05:01.319] iteration 2000 [1820.18 sec]: learning rate : 0.000100 loss : 1.273093 +[00:06:16.808] Epoch 0 Evaluation: +[00:07:07.392] average MSE: 0.052083611488342285 average PSNR: 28.166547251356686 average SSIM: 0.6882009233109193 +[00:07:23.132] iteration 2100 [15.67 sec]: learning rate : 0.000100 loss : 0.911856 +[00:08:54.020] iteration 2200 [106.56 sec]: learning rate : 0.000100 loss : 0.808895 +[00:10:25.064] iteration 2300 [197.60 sec]: learning rate : 0.000100 loss : 0.759366 +[00:11:56.027] iteration 2400 [288.57 sec]: learning rate : 0.000100 loss : 0.437011 +[00:13:26.926] iteration 2500 [379.47 sec]: learning rate : 0.000100 loss : 0.648811 +[00:14:57.848] iteration 2600 [470.39 sec]: learning rate : 0.000100 loss : 0.581172 +[00:16:28.777] iteration 2700 [561.32 sec]: learning rate : 0.000100 loss : 0.731982 +[00:17:59.694] iteration 2800 [652.23 sec]: learning rate : 0.000100 loss : 0.667945 +[00:19:30.666] iteration 2900 [743.21 sec]: learning rate : 0.000100 loss : 0.663000 +[00:21:01.623] iteration 3000 [834.16 sec]: learning rate : 0.000100 loss : 0.931858 +[00:22:32.515] iteration 3100 [925.05 sec]: learning rate : 0.000100 loss : 0.706683 +[00:24:03.467] iteration 3200 [1016.01 sec]: learning rate : 0.000100 loss : 0.694515 +[00:25:34.375] iteration 3300 [1106.92 sec]: learning rate : 0.000100 loss : 0.493996 +[00:27:05.281] iteration 3400 [1197.82 sec]: learning rate : 0.000100 loss : 0.852053 +[00:28:36.226] iteration 
3500 [1288.77 sec]: learning rate : 0.000100 loss : 0.937622 +[00:30:07.118] iteration 3600 [1379.66 sec]: learning rate : 0.000100 loss : 0.973131 +[00:31:38.057] iteration 3700 [1470.60 sec]: learning rate : 0.000100 loss : 1.168276 +[00:33:09.013] iteration 3800 [1561.55 sec]: learning rate : 0.000100 loss : 0.580963 +[00:34:39.923] iteration 3900 [1652.46 sec]: learning rate : 0.000100 loss : 0.584199 +[00:36:10.881] iteration 4000 [1743.42 sec]: learning rate : 0.000100 loss : 0.914120 +[00:37:41.881] iteration 4100 [1834.42 sec]: learning rate : 0.000100 loss : 0.774764 +[00:38:41.849] Epoch 1 Evaluation: +[00:39:33.847] average MSE: 0.0523013174533844 average PSNR: 28.15989446084589 average SSIM: 0.6990111970265264 +[00:40:05.059] iteration 4200 [31.14 sec]: learning rate : 0.000100 loss : 0.606452 +[00:41:36.068] iteration 4300 [122.15 sec]: learning rate : 0.000100 loss : 0.439096 +[00:43:07.063] iteration 4400 [213.15 sec]: learning rate : 0.000100 loss : 0.692683 +[00:44:37.973] iteration 4500 [304.06 sec]: learning rate : 0.000100 loss : 0.591143 +[00:46:08.899] iteration 4600 [394.98 sec]: learning rate : 0.000100 loss : 0.588003 +[00:47:39.800] iteration 4700 [485.88 sec]: learning rate : 0.000100 loss : 0.534172 +[00:49:10.769] iteration 4800 [576.86 sec]: learning rate : 0.000100 loss : 0.851590 +[00:50:41.718] iteration 4900 [667.80 sec]: learning rate : 0.000100 loss : 0.561469 +[00:52:12.615] iteration 5000 [758.70 sec]: learning rate : 0.000100 loss : 0.679012 +[00:53:43.570] iteration 5100 [849.66 sec]: learning rate : 0.000100 loss : 0.859858 +[00:55:14.470] iteration 5200 [940.56 sec]: learning rate : 0.000100 loss : 0.551918 +[00:56:45.422] iteration 5300 [1031.51 sec]: learning rate : 0.000100 loss : 0.593098 +[00:58:16.400] iteration 5400 [1122.48 sec]: learning rate : 0.000100 loss : 0.417584 +[00:59:47.309] iteration 5500 [1213.40 sec]: learning rate : 0.000100 loss : 0.698338 +[01:01:18.280] iteration 5600 [1304.37 sec]: learning rate : 0.000100 loss : 0.394753 +[01:02:49.210] iteration 5700 [1395.30 sec]: learning rate : 0.000100 loss : 0.796182 +[01:04:20.116] iteration 5800 [1486.20 sec]: learning rate : 0.000100 loss : 0.845034 +[01:05:51.076] iteration 5900 [1577.16 sec]: learning rate : 0.000100 loss : 0.550440 +[01:07:21.962] iteration 6000 [1668.05 sec]: learning rate : 0.000100 loss : 0.672448 +[01:08:52.907] iteration 6100 [1758.99 sec]: learning rate : 0.000100 loss : 0.495482 +[01:10:23.893] iteration 6200 [1849.98 sec]: learning rate : 0.000100 loss : 0.455312 +[01:11:08.438] Epoch 2 Evaluation: +[01:12:01.518] average MSE: 0.049506932497024536 average PSNR: 28.476793359819066 average SSIM: 0.7068383279023646 +[01:12:48.183] iteration 6300 [46.60 sec]: learning rate : 0.000100 loss : 1.123563 +[01:14:19.209] iteration 6400 [137.62 sec]: learning rate : 0.000100 loss : 0.686327 +[01:15:50.148] iteration 6500 [228.56 sec]: learning rate : 0.000100 loss : 0.702656 +[01:17:21.056] iteration 6600 [319.47 sec]: learning rate : 0.000100 loss : 0.479897 +[01:18:52.012] iteration 6700 [410.43 sec]: learning rate : 0.000100 loss : 0.667063 +[01:20:22.926] iteration 6800 [501.34 sec]: learning rate : 0.000100 loss : 0.553630 +[01:21:53.849] iteration 6900 [592.27 sec]: learning rate : 0.000100 loss : 0.837110 +[01:23:24.827] iteration 7000 [683.24 sec]: learning rate : 0.000100 loss : 0.359016 +[01:24:55.718] iteration 7100 [774.13 sec]: learning rate : 0.000100 loss : 0.662692 +[01:26:26.698] iteration 7200 [865.11 sec]: learning rate : 0.000100 loss : 
0.366699 +[01:27:57.676] iteration 7300 [956.09 sec]: learning rate : 0.000100 loss : 0.451207 +[01:29:28.561] iteration 7400 [1046.98 sec]: learning rate : 0.000100 loss : 0.497083 +[01:30:59.490] iteration 7500 [1137.90 sec]: learning rate : 0.000100 loss : 0.348612 +[01:32:30.439] iteration 7600 [1228.92 sec]: learning rate : 0.000100 loss : 0.521196 +[01:34:01.344] iteration 7700 [1319.76 sec]: learning rate : 0.000100 loss : 0.507766 +[01:35:32.334] iteration 7800 [1410.75 sec]: learning rate : 0.000100 loss : 0.267644 +[01:37:03.220] iteration 7900 [1501.63 sec]: learning rate : 0.000100 loss : 0.426973 +[01:38:34.155] iteration 8000 [1592.57 sec]: learning rate : 0.000100 loss : 0.478598 +[01:40:05.086] iteration 8100 [1683.50 sec]: learning rate : 0.000100 loss : 0.526327 +[01:41:35.993] iteration 8200 [1774.41 sec]: learning rate : 0.000100 loss : 0.432233 +[01:43:06.924] iteration 8300 [1865.34 sec]: learning rate : 0.000100 loss : 0.453909 +[01:43:35.984] Epoch 3 Evaluation: +[01:44:26.724] average MSE: 0.046711266040802 average PSNR: 28.766678547019918 average SSIM: 0.7103327701265335 +[01:45:28.949] iteration 8400 [62.16 sec]: learning rate : 0.000100 loss : 0.909527 +[01:46:59.858] iteration 8500 [153.07 sec]: learning rate : 0.000100 loss : 0.351569 +[01:48:30.826] iteration 8600 [244.03 sec]: learning rate : 0.000100 loss : 0.489406 +[01:50:01.765] iteration 8700 [334.97 sec]: learning rate : 0.000100 loss : 0.249047 +[01:51:32.660] iteration 8800 [425.87 sec]: learning rate : 0.000100 loss : 0.416822 +[01:53:03.610] iteration 8900 [516.82 sec]: learning rate : 0.000100 loss : 0.432970 +[01:54:34.566] iteration 9000 [607.77 sec]: learning rate : 0.000100 loss : 0.353625 +[01:56:05.453] iteration 9100 [698.66 sec]: learning rate : 0.000100 loss : 0.463427 +[01:57:36.359] iteration 9200 [789.57 sec]: learning rate : 0.000100 loss : 0.887383 +[01:59:07.272] iteration 9300 [880.48 sec]: learning rate : 0.000100 loss : 0.472291 +[02:00:38.214] iteration 9400 [971.42 sec]: learning rate : 0.000100 loss : 0.413563 +[02:02:09.176] iteration 9500 [1062.38 sec]: learning rate : 0.000100 loss : 0.606498 +[02:03:40.068] iteration 9600 [1153.28 sec]: learning rate : 0.000100 loss : 0.373390 +[02:05:11.021] iteration 9700 [1244.23 sec]: learning rate : 0.000100 loss : 0.613581 +[02:06:41.967] iteration 9800 [1335.18 sec]: learning rate : 0.000100 loss : 0.745537 +[02:08:12.858] iteration 9900 [1426.07 sec]: learning rate : 0.000100 loss : 0.528753 +[02:09:43.824] iteration 10000 [1517.03 sec]: learning rate : 0.000100 loss : 0.349934 +[02:11:14.749] iteration 10100 [1607.96 sec]: learning rate : 0.000100 loss : 0.347502 +[02:12:45.727] iteration 10200 [1698.94 sec]: learning rate : 0.000100 loss : 0.514146 +[02:14:16.688] iteration 10300 [1789.90 sec]: learning rate : 0.000100 loss : 0.576310 +[02:15:47.590] iteration 10400 [1880.80 sec]: learning rate : 0.000100 loss : 0.638786 +[02:16:01.215] Epoch 4 Evaluation: +[02:16:53.613] average MSE: 0.04393121600151062 average PSNR: 29.058637837835374 average SSIM: 0.7157130774470509 +[02:18:11.157] iteration 10500 [77.48 sec]: learning rate : 0.000100 loss : 0.386306 +[02:19:42.180] iteration 10600 [168.50 sec]: learning rate : 0.000100 loss : 0.674246 +[02:21:13.082] iteration 10700 [259.40 sec]: learning rate : 0.000100 loss : 0.479376 +[02:22:44.002] iteration 10800 [350.32 sec]: learning rate : 0.000100 loss : 0.731016 +[02:24:14.908] iteration 10900 [441.23 sec]: learning rate : 0.000100 loss : 0.326050 +[02:25:45.883] iteration 11000 
[532.20 sec]: learning rate : 0.000100 loss : 0.676414 +[02:27:16.884] iteration 11100 [623.20 sec]: learning rate : 0.000100 loss : 0.584250 +[02:28:47.784] iteration 11200 [714.10 sec]: learning rate : 0.000100 loss : 0.325216 +[02:30:18.755] iteration 11300 [805.07 sec]: learning rate : 0.000100 loss : 0.290992 +[02:31:49.684] iteration 11400 [896.00 sec]: learning rate : 0.000100 loss : 0.507458 +[02:33:20.623] iteration 11500 [986.94 sec]: learning rate : 0.000100 loss : 0.544806 +[02:34:51.588] iteration 11600 [1077.91 sec]: learning rate : 0.000100 loss : 0.623896 +[02:36:22.535] iteration 11700 [1168.85 sec]: learning rate : 0.000100 loss : 0.746427 +[02:37:53.434] iteration 11800 [1259.75 sec]: learning rate : 0.000100 loss : 0.416399 +[02:39:24.383] iteration 11900 [1350.70 sec]: learning rate : 0.000100 loss : 0.530423 +[02:40:55.280] iteration 12000 [1441.60 sec]: learning rate : 0.000100 loss : 0.298308 +[02:42:26.257] iteration 12100 [1532.58 sec]: learning rate : 0.000100 loss : 0.870794 +[02:43:57.210] iteration 12200 [1623.53 sec]: learning rate : 0.000100 loss : 0.617099 +[02:45:28.095] iteration 12300 [1714.41 sec]: learning rate : 0.000100 loss : 0.447822 +[02:46:59.051] iteration 12400 [1805.37 sec]: learning rate : 0.000100 loss : 0.261168 +[02:48:28.180] Epoch 5 Evaluation: +[02:49:21.185] average MSE: 0.04385746642947197 average PSNR: 29.0923527805322 average SSIM: 0.7152118551529129 +[02:49:23.293] iteration 12500 [2.04 sec]: learning rate : 0.000100 loss : 0.368382 +[02:50:54.205] iteration 12600 [92.95 sec]: learning rate : 0.000100 loss : 0.772542 +[02:52:25.204] iteration 12700 [183.95 sec]: learning rate : 0.000100 loss : 0.532949 +[02:53:56.099] iteration 12800 [274.85 sec]: learning rate : 0.000100 loss : 0.494408 +[02:55:27.048] iteration 12900 [365.80 sec]: learning rate : 0.000100 loss : 0.472541 +[02:56:57.960] iteration 13000 [456.71 sec]: learning rate : 0.000100 loss : 0.616311 +[02:58:28.865] iteration 13100 [547.61 sec]: learning rate : 0.000100 loss : 1.004819 +[02:59:59.800] iteration 13200 [638.55 sec]: learning rate : 0.000100 loss : 0.308819 +[03:01:30.750] iteration 13300 [729.50 sec]: learning rate : 0.000100 loss : 0.536576 +[03:03:01.664] iteration 13400 [820.41 sec]: learning rate : 0.000100 loss : 0.898157 +[03:04:32.622] iteration 13500 [911.37 sec]: learning rate : 0.000100 loss : 0.758978 +[03:06:03.598] iteration 13600 [1002.35 sec]: learning rate : 0.000100 loss : 0.424201 +[03:07:34.489] iteration 13700 [1093.24 sec]: learning rate : 0.000100 loss : 0.494892 +[03:09:05.490] iteration 13800 [1184.24 sec]: learning rate : 0.000100 loss : 0.792201 +[03:10:36.423] iteration 13900 [1275.17 sec]: learning rate : 0.000100 loss : 0.353038 +[03:12:07.380] iteration 14000 [1366.13 sec]: learning rate : 0.000100 loss : 0.317470 +[03:13:38.310] iteration 14100 [1457.06 sec]: learning rate : 0.000100 loss : 0.532166 +[03:15:09.222] iteration 14200 [1547.97 sec]: learning rate : 0.000100 loss : 0.598017 +[03:16:40.215] iteration 14300 [1638.96 sec]: learning rate : 0.000100 loss : 0.792128 +[03:18:11.172] iteration 14400 [1729.92 sec]: learning rate : 0.000100 loss : 0.562804 +[03:19:42.085] iteration 14500 [1820.83 sec]: learning rate : 0.000100 loss : 0.527301 +[03:20:55.756] Epoch 6 Evaluation: +[03:21:47.374] average MSE: 0.043408945202827454 average PSNR: 29.198706646259044 average SSIM: 0.7206405223147119 +[03:22:04.921] iteration 14600 [17.48 sec]: learning rate : 0.000100 loss : 0.421454 +[03:23:35.908] iteration 14700 [108.47 sec]: 
learning rate : 0.000100 loss : 0.461820 +[03:25:06.813] iteration 14800 [199.37 sec]: learning rate : 0.000100 loss : 0.518653 +[03:26:37.788] iteration 14900 [290.35 sec]: learning rate : 0.000100 loss : 0.505102 +[03:28:08.750] iteration 15000 [381.31 sec]: learning rate : 0.000100 loss : 0.354999 +[03:29:39.628] iteration 15100 [472.19 sec]: learning rate : 0.000100 loss : 0.652035 +[03:31:10.602] iteration 15200 [563.16 sec]: learning rate : 0.000100 loss : 0.440540 +[03:32:41.496] iteration 15300 [654.05 sec]: learning rate : 0.000100 loss : 0.611976 +[03:34:12.445] iteration 15400 [745.00 sec]: learning rate : 0.000100 loss : 0.297308 +[03:35:43.405] iteration 15500 [835.96 sec]: learning rate : 0.000100 loss : 0.444464 +[03:37:14.315] iteration 15600 [926.87 sec]: learning rate : 0.000100 loss : 0.419943 +[03:38:45.255] iteration 15700 [1017.81 sec]: learning rate : 0.000100 loss : 0.366550 +[03:40:16.185] iteration 15800 [1108.74 sec]: learning rate : 0.000100 loss : 0.499091 +[03:41:47.086] iteration 15900 [1199.65 sec]: learning rate : 0.000100 loss : 0.561861 +[03:43:18.035] iteration 16000 [1290.60 sec]: learning rate : 0.000100 loss : 0.751011 +[03:44:48.962] iteration 16100 [1381.52 sec]: learning rate : 0.000100 loss : 0.640022 +[03:46:19.932] iteration 16200 [1472.49 sec]: learning rate : 0.000100 loss : 0.638821 +[03:47:50.884] iteration 16300 [1563.44 sec]: learning rate : 0.000100 loss : 0.295250 +[03:49:21.776] iteration 16400 [1654.33 sec]: learning rate : 0.000100 loss : 0.713771 +[03:50:52.727] iteration 16500 [1745.29 sec]: learning rate : 0.000100 loss : 0.792403 +[03:52:23.685] iteration 16600 [1836.24 sec]: learning rate : 0.000100 loss : 0.481343 +[03:53:21.831] Epoch 7 Evaluation: +[03:54:13.386] average MSE: 0.04207322373986244 average PSNR: 29.360627408818956 average SSIM: 0.7237924629224542 +[03:54:46.376] iteration 16700 [32.92 sec]: learning rate : 0.000100 loss : 0.316321 +[03:56:17.374] iteration 16800 [123.92 sec]: learning rate : 0.000100 loss : 0.440553 +[03:57:48.306] iteration 16900 [214.85 sec]: learning rate : 0.000100 loss : 0.421092 +[03:59:19.211] iteration 17000 [305.76 sec]: learning rate : 0.000100 loss : 0.307229 +[04:00:50.152] iteration 17100 [396.70 sec]: learning rate : 0.000100 loss : 0.412982 +[04:02:21.103] iteration 17200 [487.65 sec]: learning rate : 0.000100 loss : 0.452248 +[04:03:52.004] iteration 17300 [578.55 sec]: learning rate : 0.000100 loss : 0.917094 +[04:05:22.963] iteration 17400 [669.51 sec]: learning rate : 0.000100 loss : 0.681891 +[04:06:53.851] iteration 17500 [760.40 sec]: learning rate : 0.000100 loss : 0.635233 +[04:08:24.745] iteration 17600 [851.29 sec]: learning rate : 0.000100 loss : 0.465233 +[04:09:55.686] iteration 17700 [942.23 sec]: learning rate : 0.000100 loss : 0.436869 +[04:11:26.575] iteration 17800 [1033.12 sec]: learning rate : 0.000100 loss : 0.584993 +[04:12:57.528] iteration 17900 [1124.08 sec]: learning rate : 0.000100 loss : 0.491723 +[04:14:28.533] iteration 18000 [1215.08 sec]: learning rate : 0.000100 loss : 0.594206 +[04:15:59.430] iteration 18100 [1305.98 sec]: learning rate : 0.000100 loss : 0.529270 +[04:17:30.373] iteration 18200 [1396.92 sec]: learning rate : 0.000100 loss : 0.565421 +[04:19:01.267] iteration 18300 [1487.81 sec]: learning rate : 0.000100 loss : 0.832999 +[04:20:32.224] iteration 18400 [1578.77 sec]: learning rate : 0.000100 loss : 0.587193 +[04:22:03.224] iteration 18500 [1669.77 sec]: learning rate : 0.000100 loss : 0.574064 +[04:23:34.126] iteration 18600 
[1760.67 sec]: learning rate : 0.000100 loss : 0.403993 +[04:25:05.108] iteration 18700 [1851.65 sec]: learning rate : 0.000100 loss : 0.576969 +[04:25:47.805] Epoch 8 Evaluation: +[04:26:38.081] average MSE: 0.04205470159649849 average PSNR: 29.343005159644587 average SSIM: 0.726140998554615 +[04:27:26.660] iteration 18800 [48.51 sec]: learning rate : 0.000100 loss : 0.475493 +[04:28:57.547] iteration 18900 [139.40 sec]: learning rate : 0.000100 loss : 0.377687 +[04:30:28.461] iteration 19000 [230.31 sec]: learning rate : 0.000100 loss : 0.356015 +[04:31:59.348] iteration 19100 [321.20 sec]: learning rate : 0.000100 loss : 0.299239 +[04:33:30.277] iteration 19200 [412.13 sec]: learning rate : 0.000100 loss : 0.612497 +[04:35:01.239] iteration 19300 [503.09 sec]: learning rate : 0.000100 loss : 0.629012 +[04:36:32.130] iteration 19400 [593.98 sec]: learning rate : 0.000100 loss : 0.476803 +[04:38:03.053] iteration 19500 [684.90 sec]: learning rate : 0.000100 loss : 0.318801 +[04:39:34.023] iteration 19600 [775.88 sec]: learning rate : 0.000100 loss : 0.621748 +[04:41:04.921] iteration 19700 [866.77 sec]: learning rate : 0.000100 loss : 0.403969 +[04:42:35.876] iteration 19800 [957.73 sec]: learning rate : 0.000100 loss : 0.442042 +[04:44:06.846] iteration 19900 [1048.76 sec]: learning rate : 0.000100 loss : 0.599009 +[04:45:37.757] iteration 20000 [1139.61 sec]: learning rate : 0.000025 loss : 0.629008 +[04:45:37.910] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_20000.pth +[04:47:08.856] iteration 20100 [1230.71 sec]: learning rate : 0.000050 loss : 0.566460 +[04:48:39.734] iteration 20200 [1321.59 sec]: learning rate : 0.000050 loss : 0.569923 +[04:50:10.632] iteration 20300 [1412.48 sec]: learning rate : 0.000050 loss : 0.504588 +[04:51:41.522] iteration 20400 [1503.37 sec]: learning rate : 0.000050 loss : 0.595861 +[04:53:12.409] iteration 20500 [1594.26 sec]: learning rate : 0.000050 loss : 0.353715 +[04:54:43.339] iteration 20600 [1685.19 sec]: learning rate : 0.000050 loss : 0.491937 +[04:56:14.274] iteration 20700 [1776.12 sec]: learning rate : 0.000050 loss : 0.748876 +[04:57:45.169] iteration 20800 [1867.02 sec]: learning rate : 0.000050 loss : 0.369658 +[04:58:12.488] Epoch 9 Evaluation: +[04:59:04.946] average MSE: 0.04063296690583229 average PSNR: 29.535049624129858 average SSIM: 0.7283188830324696 +[05:00:08.874] iteration 20900 [63.86 sec]: learning rate : 0.000050 loss : 0.476149 +[05:01:39.869] iteration 21000 [154.86 sec]: learning rate : 0.000050 loss : 0.485305 +[05:03:10.766] iteration 21100 [245.75 sec]: learning rate : 0.000050 loss : 0.388947 +[05:04:41.687] iteration 21200 [336.67 sec]: learning rate : 0.000050 loss : 0.299005 +[05:06:12.621] iteration 21300 [427.61 sec]: learning rate : 0.000050 loss : 0.442499 +[05:07:43.601] iteration 21400 [518.59 sec]: learning rate : 0.000050 loss : 0.477599 +[05:09:14.528] iteration 21500 [609.51 sec]: learning rate : 0.000050 loss : 0.741858 +[05:10:45.413] iteration 21600 [700.40 sec]: learning rate : 0.000050 loss : 0.365357 +[05:12:16.345] iteration 21700 [791.33 sec]: learning rate : 0.000050 loss : 0.258487 +[05:13:47.296] iteration 21800 [882.28 sec]: learning rate : 0.000050 loss : 0.367111 +[05:15:18.192] iteration 21900 [973.18 sec]: learning rate : 0.000050 loss : 0.670486 +[05:16:49.169] iteration 22000 [1064.16 sec]: learning rate : 0.000050 loss : 0.388939 +[05:18:20.075] iteration 22100 [1155.06 sec]: learning rate : 0.000050 loss : 0.457944 +[05:19:50.977] iteration 22200 [1245.96 sec]: learning 
rate : 0.000050 loss : 0.618238 +[05:21:21.952] iteration 22300 [1336.94 sec]: learning rate : 0.000050 loss : 0.392950 +[05:22:52.906] iteration 22400 [1427.89 sec]: learning rate : 0.000050 loss : 0.411173 +[05:24:23.803] iteration 22500 [1518.79 sec]: learning rate : 0.000050 loss : 0.728031 +[05:25:54.752] iteration 22600 [1609.74 sec]: learning rate : 0.000050 loss : 0.438612 +[05:27:25.651] iteration 22700 [1700.64 sec]: learning rate : 0.000050 loss : 0.604349 +[05:28:56.620] iteration 22800 [1791.61 sec]: learning rate : 0.000050 loss : 0.318898 +[05:30:27.576] iteration 22900 [1882.56 sec]: learning rate : 0.000050 loss : 0.402081 +[05:30:39.378] Epoch 10 Evaluation: +[05:31:32.044] average MSE: 0.04036292806267738 average PSNR: 29.579878758451684 average SSIM: 0.729254196281737 +[05:32:51.447] iteration 23000 [79.34 sec]: learning rate : 0.000050 loss : 0.466352 +[05:34:22.437] iteration 23100 [170.33 sec]: learning rate : 0.000050 loss : 0.630977 +[05:35:53.369] iteration 23200 [261.26 sec]: learning rate : 0.000050 loss : 0.446462 +[05:37:24.253] iteration 23300 [352.14 sec]: learning rate : 0.000050 loss : 0.626567 +[05:38:55.196] iteration 23400 [443.08 sec]: learning rate : 0.000050 loss : 0.348047 +[05:40:26.146] iteration 23500 [534.03 sec]: learning rate : 0.000050 loss : 0.690197 +[05:41:57.067] iteration 23600 [624.95 sec]: learning rate : 0.000050 loss : 0.442779 +[05:43:28.016] iteration 23700 [715.90 sec]: learning rate : 0.000050 loss : 0.293989 +[05:44:58.967] iteration 23800 [806.86 sec]: learning rate : 0.000050 loss : 0.493827 +[05:46:29.865] iteration 23900 [897.75 sec]: learning rate : 0.000050 loss : 0.469426 +[05:48:00.808] iteration 24000 [988.70 sec]: learning rate : 0.000050 loss : 0.710362 +[05:49:31.695] iteration 24100 [1079.58 sec]: learning rate : 0.000050 loss : 0.482910 +[05:51:02.659] iteration 24200 [1170.55 sec]: learning rate : 0.000050 loss : 1.045676 +[05:52:33.595] iteration 24300 [1261.48 sec]: learning rate : 0.000050 loss : 0.426108 +[05:54:04.544] iteration 24400 [1352.43 sec]: learning rate : 0.000050 loss : 0.471525 +[05:55:35.493] iteration 24500 [1443.38 sec]: learning rate : 0.000050 loss : 0.282425 +[05:57:06.436] iteration 24600 [1534.32 sec]: learning rate : 0.000050 loss : 0.980447 +[05:58:37.320] iteration 24700 [1625.21 sec]: learning rate : 0.000050 loss : 0.307383 +[06:00:08.255] iteration 24800 [1716.15 sec]: learning rate : 0.000050 loss : 0.656804 +[06:01:39.160] iteration 24900 [1807.05 sec]: learning rate : 0.000050 loss : 0.405413 +[06:03:06.442] Epoch 11 Evaluation: +[06:03:57.226] average MSE: 0.04019030183553696 average PSNR: 29.614837414668017 average SSIM: 0.7302369752267289 +[06:04:01.149] iteration 25000 [3.86 sec]: learning rate : 0.000050 loss : 0.387788 +[06:05:32.170] iteration 25100 [94.88 sec]: learning rate : 0.000050 loss : 0.673012 +[06:07:03.069] iteration 25200 [185.78 sec]: learning rate : 0.000050 loss : 0.379475 +[06:08:34.021] iteration 25300 [276.73 sec]: learning rate : 0.000050 loss : 0.308547 +[06:10:04.985] iteration 25400 [367.69 sec]: learning rate : 0.000050 loss : 0.521117 +[06:11:35.934] iteration 25500 [458.64 sec]: learning rate : 0.000050 loss : 0.764554 +[06:13:06.933] iteration 25600 [549.64 sec]: learning rate : 0.000050 loss : 0.814193 +[06:14:37.922] iteration 25700 [640.63 sec]: learning rate : 0.000050 loss : 0.437331 +[06:16:08.829] iteration 25800 [731.53 sec]: learning rate : 0.000050 loss : 0.400501 +[06:17:39.789] iteration 25900 [822.50 sec]: learning rate : 0.000050 loss 
: 0.817996 +[06:19:10.711] iteration 26000 [913.42 sec]: learning rate : 0.000050 loss : 0.654725 +[06:20:41.691] iteration 26100 [1004.40 sec]: learning rate : 0.000050 loss : 0.515304 +[06:22:12.668] iteration 26200 [1095.38 sec]: learning rate : 0.000050 loss : 0.711710 +[06:23:43.577] iteration 26300 [1186.28 sec]: learning rate : 0.000050 loss : 0.414067 +[06:25:14.508] iteration 26400 [1277.21 sec]: learning rate : 0.000050 loss : 0.460812 +[06:26:45.447] iteration 26500 [1368.15 sec]: learning rate : 0.000050 loss : 0.408211 +[06:28:16.330] iteration 26600 [1459.04 sec]: learning rate : 0.000050 loss : 0.373600 +[06:29:47.264] iteration 26700 [1549.97 sec]: learning rate : 0.000050 loss : 0.376957 +[06:31:18.152] iteration 26800 [1640.86 sec]: learning rate : 0.000050 loss : 0.585041 +[06:32:49.099] iteration 26900 [1731.80 sec]: learning rate : 0.000050 loss : 0.466475 +[06:34:20.058] iteration 27000 [1822.76 sec]: learning rate : 0.000050 loss : 0.829205 +[06:35:31.843] Epoch 12 Evaluation: +[06:36:23.795] average MSE: 0.03992251679301262 average PSNR: 29.63436785662357 average SSIM: 0.7308197477903112 +[06:36:43.196] iteration 27100 [19.33 sec]: learning rate : 0.000050 loss : 0.394265 +[06:38:14.224] iteration 27200 [110.36 sec]: learning rate : 0.000050 loss : 0.467709 +[06:39:45.202] iteration 27300 [201.34 sec]: learning rate : 0.000050 loss : 0.589970 +[06:41:16.112] iteration 27400 [292.25 sec]: learning rate : 0.000050 loss : 0.255207 +[06:42:47.044] iteration 27500 [383.18 sec]: learning rate : 0.000050 loss : 0.416692 +[06:44:18.005] iteration 27600 [474.14 sec]: learning rate : 0.000050 loss : 0.691401 +[06:45:48.903] iteration 27700 [565.04 sec]: learning rate : 0.000050 loss : 0.444303 +[06:47:19.857] iteration 27800 [656.00 sec]: learning rate : 0.000050 loss : 0.494278 +[06:48:50.768] iteration 27900 [746.91 sec]: learning rate : 0.000050 loss : 0.957615 +[06:50:21.695] iteration 28000 [837.83 sec]: learning rate : 0.000050 loss : 0.384617 +[06:51:52.664] iteration 28100 [928.80 sec]: learning rate : 0.000050 loss : 0.533876 +[06:53:23.575] iteration 28200 [1019.71 sec]: learning rate : 0.000050 loss : 0.853514 +[06:54:54.550] iteration 28300 [1110.69 sec]: learning rate : 0.000050 loss : 0.424473 +[06:56:25.522] iteration 28400 [1201.66 sec]: learning rate : 0.000050 loss : 0.561189 +[06:57:56.415] iteration 28500 [1292.55 sec]: learning rate : 0.000050 loss : 0.562364 +[06:59:27.348] iteration 28600 [1383.49 sec]: learning rate : 0.000050 loss : 0.458366 +[07:00:58.256] iteration 28700 [1474.39 sec]: learning rate : 0.000050 loss : 0.478985 +[07:02:29.138] iteration 28800 [1565.28 sec]: learning rate : 0.000050 loss : 0.634065 +[07:04:00.069] iteration 28900 [1656.21 sec]: learning rate : 0.000050 loss : 0.347051 +[07:05:30.978] iteration 29000 [1747.12 sec]: learning rate : 0.000050 loss : 0.620922 +[07:07:01.973] iteration 29100 [1838.11 sec]: learning rate : 0.000050 loss : 0.517713 +[07:07:58.292] Epoch 13 Evaluation: +[07:08:49.491] average MSE: 0.039458807557821274 average PSNR: 29.686848516519575 average SSIM: 0.7322088357626796 +[07:09:24.540] iteration 29200 [34.98 sec]: learning rate : 0.000050 loss : 0.775967 +[07:10:55.426] iteration 29300 [125.87 sec]: learning rate : 0.000050 loss : 0.491150 +[07:12:26.387] iteration 29400 [216.83 sec]: learning rate : 0.000050 loss : 0.322735 +[07:13:57.291] iteration 29500 [307.73 sec]: learning rate : 0.000050 loss : 0.610651 +[07:15:28.195] iteration 29600 [398.64 sec]: learning rate : 0.000050 loss : 0.284070 
+[07:16:59.157] iteration 29700 [489.60 sec]: learning rate : 0.000050 loss : 0.653319 +[07:18:30.106] iteration 29800 [580.55 sec]: learning rate : 0.000050 loss : 0.456362 +[07:20:00.993] iteration 29900 [671.44 sec]: learning rate : 0.000050 loss : 0.606682 +[07:21:31.901] iteration 30000 [762.34 sec]: learning rate : 0.000050 loss : 0.498909 +[07:23:02.808] iteration 30100 [853.25 sec]: learning rate : 0.000050 loss : 0.928487 +[07:24:33.780] iteration 30200 [944.22 sec]: learning rate : 0.000050 loss : 0.412615 +[07:26:04.719] iteration 30300 [1035.16 sec]: learning rate : 0.000050 loss : 0.610963 +[07:27:35.638] iteration 30400 [1126.08 sec]: learning rate : 0.000050 loss : 0.558136 +[07:29:06.571] iteration 30500 [1217.01 sec]: learning rate : 0.000050 loss : 0.299014 +[07:30:37.477] iteration 30600 [1307.92 sec]: learning rate : 0.000050 loss : 0.508916 +[07:32:08.425] iteration 30700 [1398.87 sec]: learning rate : 0.000050 loss : 0.620398 +[07:33:39.368] iteration 30800 [1489.81 sec]: learning rate : 0.000050 loss : 0.530738 +[07:35:10.259] iteration 30900 [1580.70 sec]: learning rate : 0.000050 loss : 0.389487 +[07:36:41.165] iteration 31000 [1671.61 sec]: learning rate : 0.000050 loss : 0.642807 +[07:38:12.093] iteration 31100 [1762.53 sec]: learning rate : 0.000050 loss : 0.525595 +[07:39:42.980] iteration 31200 [1853.42 sec]: learning rate : 0.000050 loss : 0.527612 +[07:40:23.911] Epoch 14 Evaluation: +[07:41:14.510] average MSE: 0.03945602849125862 average PSNR: 29.70087088072418 average SSIM: 0.731548484752393 +[07:42:04.813] iteration 31300 [50.23 sec]: learning rate : 0.000050 loss : 0.414191 +[07:43:35.802] iteration 31400 [141.29 sec]: learning rate : 0.000050 loss : 0.402359 +[07:45:06.686] iteration 31500 [232.11 sec]: learning rate : 0.000050 loss : 0.280064 +[07:46:37.619] iteration 31600 [323.04 sec]: learning rate : 0.000050 loss : 0.318800 +[07:48:08.522] iteration 31700 [413.94 sec]: learning rate : 0.000050 loss : 0.295089 +[07:49:39.496] iteration 31800 [504.92 sec]: learning rate : 0.000050 loss : 0.408286 +[07:51:10.491] iteration 31900 [595.91 sec]: learning rate : 0.000050 loss : 0.566664 +[07:52:41.381] iteration 32000 [686.80 sec]: learning rate : 0.000050 loss : 0.408326 +[07:54:12.298] iteration 32100 [777.72 sec]: learning rate : 0.000050 loss : 0.331239 +[07:55:43.241] iteration 32200 [868.66 sec]: learning rate : 0.000050 loss : 0.289617 +[07:57:14.139] iteration 32300 [959.56 sec]: learning rate : 0.000050 loss : 0.323061 +[07:58:45.049] iteration 32400 [1050.47 sec]: learning rate : 0.000050 loss : 0.506170 +[08:00:15.966] iteration 32500 [1141.39 sec]: learning rate : 0.000050 loss : 0.433222 +[08:01:46.949] iteration 32600 [1232.37 sec]: learning rate : 0.000050 loss : 0.269466 +[08:03:17.923] iteration 32700 [1323.35 sec]: learning rate : 0.000050 loss : 0.415570 +[08:04:48.806] iteration 32800 [1414.23 sec]: learning rate : 0.000050 loss : 0.415946 +[08:06:19.763] iteration 32900 [1505.19 sec]: learning rate : 0.000050 loss : 0.521256 +[08:07:50.698] iteration 33000 [1596.12 sec]: learning rate : 0.000050 loss : 0.356830 +[08:09:21.598] iteration 33100 [1687.02 sec]: learning rate : 0.000050 loss : 0.573287 +[08:10:52.531] iteration 33200 [1777.95 sec]: learning rate : 0.000050 loss : 0.468607 +[08:12:23.485] iteration 33300 [1868.91 sec]: learning rate : 0.000050 loss : 0.350222 +[08:12:48.910] Epoch 15 Evaluation: +[08:13:39.526] average MSE: 0.039351992309093475 average PSNR: 29.70598401949611 average SSIM: 0.731640907325754 +[08:14:45.244] 
iteration 33400 [65.65 sec]: learning rate : 0.000050 loss : 0.447399 +[08:16:16.200] iteration 33500 [156.61 sec]: learning rate : 0.000050 loss : 0.436414 +[08:17:47.155] iteration 33600 [247.56 sec]: learning rate : 0.000050 loss : 0.953995 +[08:19:18.037] iteration 33700 [338.44 sec]: learning rate : 0.000050 loss : 0.313564 +[08:20:48.989] iteration 33800 [429.39 sec]: learning rate : 0.000050 loss : 0.314854 +[08:22:19.906] iteration 33900 [520.31 sec]: learning rate : 0.000050 loss : 0.424047 +[08:23:50.852] iteration 34000 [611.26 sec]: learning rate : 0.000050 loss : 0.743204 +[08:25:21.800] iteration 34100 [702.21 sec]: learning rate : 0.000050 loss : 0.569104 +[08:26:52.692] iteration 34200 [793.10 sec]: learning rate : 0.000050 loss : 0.402677 +[08:28:23.676] iteration 34300 [884.08 sec]: learning rate : 0.000050 loss : 0.310482 +[08:29:54.590] iteration 34400 [975.00 sec]: learning rate : 0.000050 loss : 0.484431 +[08:31:25.492] iteration 34500 [1065.90 sec]: learning rate : 0.000050 loss : 0.696667 +[08:32:56.458] iteration 34600 [1156.88 sec]: learning rate : 0.000050 loss : 0.665592 +[08:34:27.453] iteration 34700 [1247.86 sec]: learning rate : 0.000050 loss : 0.482073 +[08:35:58.341] iteration 34800 [1338.75 sec]: learning rate : 0.000050 loss : 0.439087 +[08:37:29.286] iteration 34900 [1429.69 sec]: learning rate : 0.000050 loss : 0.650227 +[08:39:00.201] iteration 35000 [1520.61 sec]: learning rate : 0.000050 loss : 0.336636 +[08:40:31.161] iteration 35100 [1611.57 sec]: learning rate : 0.000050 loss : 0.530378 +[08:42:02.139] iteration 35200 [1702.55 sec]: learning rate : 0.000050 loss : 0.541932 +[08:43:33.045] iteration 35300 [1793.45 sec]: learning rate : 0.000050 loss : 0.445114 +[08:45:04.008] iteration 35400 [1884.41 sec]: learning rate : 0.000050 loss : 0.274761 +[08:45:13.981] Epoch 16 Evaluation: +[08:46:06.625] average MSE: 0.039209142327308655 average PSNR: 29.7204523702007 average SSIM: 0.7316591549401394 +[08:47:27.910] iteration 35500 [81.22 sec]: learning rate : 0.000050 loss : 0.783692 +[08:48:58.795] iteration 35600 [172.10 sec]: learning rate : 0.000050 loss : 0.340930 +[08:50:29.765] iteration 35700 [263.07 sec]: learning rate : 0.000050 loss : 0.861455 +[08:52:00.735] iteration 35800 [354.05 sec]: learning rate : 0.000050 loss : 0.587017 +[08:53:31.624] iteration 35900 [444.93 sec]: learning rate : 0.000050 loss : 0.421441 +[08:55:02.584] iteration 36000 [535.89 sec]: learning rate : 0.000050 loss : 0.562954 +[08:56:33.508] iteration 36100 [626.82 sec]: learning rate : 0.000050 loss : 0.553369 +[08:58:04.473] iteration 36200 [717.79 sec]: learning rate : 0.000050 loss : 0.343218 +[08:59:35.439] iteration 36300 [808.75 sec]: learning rate : 0.000050 loss : 0.461144 +[09:01:06.355] iteration 36400 [899.66 sec]: learning rate : 0.000050 loss : 0.294346 +[09:02:37.367] iteration 36500 [990.68 sec]: learning rate : 0.000050 loss : 0.509693 +[09:04:08.307] iteration 36600 [1081.62 sec]: learning rate : 0.000050 loss : 0.234895 +[09:05:39.202] iteration 36700 [1172.51 sec]: learning rate : 0.000050 loss : 0.469450 +[09:07:10.139] iteration 36800 [1263.45 sec]: learning rate : 0.000050 loss : 0.608326 +[09:08:41.031] iteration 36900 [1354.34 sec]: learning rate : 0.000050 loss : 0.705674 +[09:10:11.924] iteration 37000 [1445.23 sec]: learning rate : 0.000050 loss : 0.408499 +[09:11:42.898] iteration 37100 [1536.21 sec]: learning rate : 0.000050 loss : 0.463214 +[09:13:13.787] iteration 37200 [1627.10 sec]: learning rate : 0.000050 loss : 0.485385 
+[09:14:44.741] iteration 37300 [1718.05 sec]: learning rate : 0.000050 loss : 0.693489 +[09:16:15.710] iteration 37400 [1809.02 sec]: learning rate : 0.000050 loss : 0.417496 +[09:17:41.138] Epoch 17 Evaluation: +[09:18:33.981] average MSE: 0.03951997682452202 average PSNR: 29.699423935747415 average SSIM: 0.7329798708211284 +[09:18:39.719] iteration 37500 [5.67 sec]: learning rate : 0.000050 loss : 0.375987 +[09:20:10.746] iteration 37600 [96.70 sec]: learning rate : 0.000050 loss : 0.382628 +[09:21:41.709] iteration 37700 [187.66 sec]: learning rate : 0.000050 loss : 0.710533 +[09:23:12.591] iteration 37800 [278.54 sec]: learning rate : 0.000050 loss : 0.411204 +[09:24:43.523] iteration 37900 [369.47 sec]: learning rate : 0.000050 loss : 0.606239 +[09:26:14.409] iteration 38000 [460.36 sec]: learning rate : 0.000050 loss : 0.719577 +[09:27:45.362] iteration 38100 [551.31 sec]: learning rate : 0.000050 loss : 0.491901 +[09:29:16.290] iteration 38200 [642.24 sec]: learning rate : 0.000050 loss : 0.546211 +[09:30:47.191] iteration 38300 [733.14 sec]: learning rate : 0.000050 loss : 0.430884 +[09:32:18.165] iteration 38400 [824.12 sec]: learning rate : 0.000050 loss : 0.565717 +[09:33:49.153] iteration 38500 [915.10 sec]: learning rate : 0.000050 loss : 0.433116 +[09:35:20.045] iteration 38600 [1006.00 sec]: learning rate : 0.000050 loss : 0.348236 +[09:36:51.031] iteration 38700 [1096.98 sec]: learning rate : 0.000050 loss : 0.329411 +[09:38:21.948] iteration 38800 [1187.90 sec]: learning rate : 0.000050 loss : 0.807906 +[09:39:52.864] iteration 38900 [1278.82 sec]: learning rate : 0.000050 loss : 0.558965 +[09:41:23.821] iteration 39000 [1369.77 sec]: learning rate : 0.000050 loss : 0.417621 +[09:42:54.730] iteration 39100 [1460.68 sec]: learning rate : 0.000050 loss : 0.309460 +[09:44:25.665] iteration 39200 [1551.62 sec]: learning rate : 0.000050 loss : 0.381870 +[09:45:56.632] iteration 39300 [1642.58 sec]: learning rate : 0.000050 loss : 0.604730 +[09:47:27.534] iteration 39400 [1733.49 sec]: learning rate : 0.000050 loss : 0.587239 +[09:48:58.535] iteration 39500 [1824.49 sec]: learning rate : 0.000050 loss : 0.581069 +[09:50:08.492] Epoch 18 Evaluation: +[09:51:00.561] average MSE: 0.03934395685791969 average PSNR: 29.709252995130097 average SSIM: 0.7329235618890452 +[09:51:21.754] iteration 39600 [21.13 sec]: learning rate : 0.000050 loss : 0.629627 +[09:52:52.709] iteration 39700 [112.08 sec]: learning rate : 0.000050 loss : 0.494658 +[09:54:23.664] iteration 39800 [203.04 sec]: learning rate : 0.000050 loss : 0.441649 +[09:55:54.557] iteration 39900 [293.93 sec]: learning rate : 0.000050 loss : 0.420557 +[09:57:25.490] iteration 40000 [384.86 sec]: learning rate : 0.000013 loss : 0.308778 +[09:57:25.681] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_40000.pth +[09:58:56.644] iteration 40100 [476.02 sec]: learning rate : 0.000025 loss : 0.425085 +[10:00:27.526] iteration 40200 [566.90 sec]: learning rate : 0.000025 loss : 0.345385 +[10:01:58.458] iteration 40300 [657.83 sec]: learning rate : 0.000025 loss : 0.280649 +[10:03:29.344] iteration 40400 [748.72 sec]: learning rate : 0.000025 loss : 0.638721 +[10:05:00.279] iteration 40500 [839.65 sec]: learning rate : 0.000025 loss : 0.405575 +[10:06:31.189] iteration 40600 [930.56 sec]: learning rate : 0.000025 loss : 0.487398 +[10:08:02.079] iteration 40700 [1021.45 sec]: learning rate : 0.000025 loss : 0.593658 +[10:09:33.048] iteration 40800 [1112.42 sec]: learning rate : 0.000025 loss : 0.384927 +[10:11:03.998] 
iteration 40900 [1203.37 sec]: learning rate : 0.000025 loss : 0.472944 +[10:12:34.885] iteration 41000 [1294.26 sec]: learning rate : 0.000025 loss : 0.377569 +[10:14:05.831] iteration 41100 [1385.20 sec]: learning rate : 0.000025 loss : 0.447699 +[10:15:36.753] iteration 41200 [1476.12 sec]: learning rate : 0.000025 loss : 0.507821 +[10:17:07.643] iteration 41300 [1567.01 sec]: learning rate : 0.000025 loss : 0.795641 +[10:18:38.586] iteration 41400 [1657.96 sec]: learning rate : 0.000025 loss : 0.617110 +[10:20:09.484] iteration 41500 [1748.85 sec]: learning rate : 0.000025 loss : 0.516766 +[10:21:40.463] iteration 41600 [1839.83 sec]: learning rate : 0.000025 loss : 0.591413 +[10:22:34.973] Epoch 19 Evaluation: +[10:23:25.647] average MSE: 0.03886766731739044 average PSNR: 29.79109203557967 average SSIM: 0.7338097458058673 +[10:24:02.426] iteration 41700 [36.71 sec]: learning rate : 0.000025 loss : 0.956321 +[10:25:33.310] iteration 41800 [127.60 sec]: learning rate : 0.000025 loss : 0.246606 +[10:27:04.219] iteration 41900 [218.50 sec]: learning rate : 0.000025 loss : 0.352143 +[10:28:35.162] iteration 42000 [309.45 sec]: learning rate : 0.000025 loss : 0.719984 +[10:30:06.055] iteration 42100 [400.34 sec]: learning rate : 0.000025 loss : 0.217130 +[10:31:37.009] iteration 42200 [491.29 sec]: learning rate : 0.000025 loss : 0.712313 +[10:33:07.912] iteration 42300 [582.20 sec]: learning rate : 0.000025 loss : 0.426589 +[10:34:38.858] iteration 42400 [673.14 sec]: learning rate : 0.000025 loss : 0.704593 +[10:36:09.822] iteration 42500 [764.11 sec]: learning rate : 0.000025 loss : 0.472393 +[10:37:40.727] iteration 42600 [855.01 sec]: learning rate : 0.000025 loss : 0.425660 +[10:39:11.680] iteration 42700 [945.97 sec]: learning rate : 0.000025 loss : 0.424377 +[10:40:42.618] iteration 42800 [1036.90 sec]: learning rate : 0.000025 loss : 0.566890 +[10:42:13.507] iteration 42900 [1127.79 sec]: learning rate : 0.000025 loss : 0.910824 +[10:43:44.461] iteration 43000 [1218.75 sec]: learning rate : 0.000025 loss : 0.344847 +[10:45:15.370] iteration 43100 [1309.66 sec]: learning rate : 0.000025 loss : 0.312238 +[10:46:46.256] iteration 43200 [1400.54 sec]: learning rate : 0.000025 loss : 0.745042 +[10:48:17.207] iteration 43300 [1491.49 sec]: learning rate : 0.000025 loss : 0.369976 +[10:49:48.155] iteration 43400 [1582.44 sec]: learning rate : 0.000025 loss : 0.379573 +[10:51:19.047] iteration 43500 [1673.33 sec]: learning rate : 0.000025 loss : 0.436930 +[10:52:50.009] iteration 43600 [1764.29 sec]: learning rate : 0.000025 loss : 0.348795 +[10:54:20.915] iteration 43700 [1855.20 sec]: learning rate : 0.000025 loss : 0.264225 +[10:54:59.993] Epoch 20 Evaluation: +[10:55:52.040] average MSE: 0.038796618580818176 average PSNR: 29.795192055971096 average SSIM: 0.7341194158291878 +[10:56:44.172] iteration 43800 [52.06 sec]: learning rate : 0.000025 loss : 0.272453 +[10:58:15.190] iteration 43900 [143.08 sec]: learning rate : 0.000025 loss : 0.419095 +[10:59:46.074] iteration 44000 [233.97 sec]: learning rate : 0.000025 loss : 1.053208 +[11:01:17.011] iteration 44100 [324.90 sec]: learning rate : 0.000025 loss : 0.520547 +[11:02:47.926] iteration 44200 [415.82 sec]: learning rate : 0.000025 loss : 0.398227 +[11:04:18.818] iteration 44300 [506.71 sec]: learning rate : 0.000025 loss : 0.696428 +[11:05:49.755] iteration 44400 [597.65 sec]: learning rate : 0.000025 loss : 0.546293 +[11:07:20.694] iteration 44500 [688.59 sec]: learning rate : 0.000025 loss : 0.427452 +[11:08:51.577] iteration 44600 
[779.47 sec]: learning rate : 0.000025 loss : 0.344344 +[11:10:22.518] iteration 44700 [870.41 sec]: learning rate : 0.000025 loss : 0.606996 +[11:11:53.398] iteration 44800 [961.29 sec]: learning rate : 0.000025 loss : 0.284840 +[11:13:24.383] iteration 44900 [1052.28 sec]: learning rate : 0.000025 loss : 0.665136 +[11:14:55.339] iteration 45000 [1143.23 sec]: learning rate : 0.000025 loss : 0.380719 +[11:16:26.238] iteration 45100 [1234.13 sec]: learning rate : 0.000025 loss : 0.403187 +[11:17:57.159] iteration 45200 [1325.05 sec]: learning rate : 0.000025 loss : 0.527407 +[11:19:28.122] iteration 45300 [1416.01 sec]: learning rate : 0.000025 loss : 0.309441 +[11:20:59.020] iteration 45400 [1506.91 sec]: learning rate : 0.000025 loss : 0.631888 +[11:22:29.989] iteration 45500 [1597.88 sec]: learning rate : 0.000025 loss : 0.545698 +[11:24:00.890] iteration 45600 [1688.78 sec]: learning rate : 0.000025 loss : 0.335917 +[11:25:31.848] iteration 45700 [1779.74 sec]: learning rate : 0.000025 loss : 0.418531 +[11:27:02.812] iteration 45800 [1870.70 sec]: learning rate : 0.000025 loss : 0.260244 +[11:27:26.411] Epoch 21 Evaluation: +[11:28:17.030] average MSE: 0.038770753890275955 average PSNR: 29.800697962318434 average SSIM: 0.734008600705705 +[11:29:24.626] iteration 45900 [67.53 sec]: learning rate : 0.000025 loss : 0.345197 +[11:30:55.619] iteration 46000 [158.52 sec]: learning rate : 0.000025 loss : 0.600761 +[11:32:26.580] iteration 46100 [249.48 sec]: learning rate : 0.000025 loss : 0.453021 +[11:33:57.463] iteration 46200 [340.37 sec]: learning rate : 0.000025 loss : 0.460668 +[11:35:28.396] iteration 46300 [431.30 sec]: learning rate : 0.000025 loss : 0.240308 +[11:36:59.350] iteration 46400 [522.25 sec]: learning rate : 0.000025 loss : 0.678491 +[11:38:30.240] iteration 46500 [613.14 sec]: learning rate : 0.000025 loss : 0.639989 +[11:40:01.171] iteration 46600 [704.08 sec]: learning rate : 0.000025 loss : 0.272039 +[11:41:32.076] iteration 46700 [794.98 sec]: learning rate : 0.000025 loss : 0.407535 +[11:43:03.075] iteration 46800 [885.98 sec]: learning rate : 0.000025 loss : 0.461778 +[11:44:34.036] iteration 46900 [976.94 sec]: learning rate : 0.000025 loss : 0.452315 +[11:46:04.938] iteration 47000 [1067.84 sec]: learning rate : 0.000025 loss : 0.304670 +[11:47:35.904] iteration 47100 [1158.81 sec]: learning rate : 0.000025 loss : 0.527540 +[11:49:06.882] iteration 47200 [1249.78 sec]: learning rate : 0.000025 loss : 0.353523 +[11:50:37.775] iteration 47300 [1340.68 sec]: learning rate : 0.000025 loss : 0.301280 +[11:52:08.737] iteration 47400 [1431.64 sec]: learning rate : 0.000025 loss : 0.640315 +[11:53:39.689] iteration 47500 [1522.59 sec]: learning rate : 0.000025 loss : 0.640754 +[11:55:10.599] iteration 47600 [1613.50 sec]: learning rate : 0.000025 loss : 0.439054 +[11:56:41.585] iteration 47700 [1704.49 sec]: learning rate : 0.000025 loss : 0.382193 +[11:58:12.508] iteration 47800 [1795.41 sec]: learning rate : 0.000025 loss : 0.449257 +[11:59:43.456] iteration 47900 [1886.36 sec]: learning rate : 0.000025 loss : 0.502068 +[11:59:51.608] Epoch 22 Evaluation: +[12:00:44.670] average MSE: 0.03881683200597763 average PSNR: 29.80164975310991 average SSIM: 0.7345940278910916 +[12:02:07.797] iteration 48000 [83.06 sec]: learning rate : 0.000025 loss : 0.804354 +[12:03:38.690] iteration 48100 [173.95 sec]: learning rate : 0.000025 loss : 0.336557 +[12:05:09.629] iteration 48200 [264.89 sec]: learning rate : 0.000025 loss : 0.688208 +[12:06:40.576] iteration 48300 [355.84 sec]: 
learning rate : 0.000025 loss : 0.637836 +[12:08:11.484] iteration 48400 [446.75 sec]: learning rate : 0.000025 loss : 0.466255 +[12:09:42.469] iteration 48500 [537.73 sec]: learning rate : 0.000025 loss : 0.369207 +[12:11:13.454] iteration 48600 [628.78 sec]: learning rate : 0.000025 loss : 0.376274 +[12:12:44.384] iteration 48700 [719.65 sec]: learning rate : 0.000025 loss : 0.495656 +[12:14:15.337] iteration 48800 [810.60 sec]: learning rate : 0.000025 loss : 0.494332 +[12:15:46.232] iteration 48900 [901.50 sec]: learning rate : 0.000025 loss : 0.640360 +[12:17:17.195] iteration 49000 [992.46 sec]: learning rate : 0.000025 loss : 0.553144 +[12:18:48.156] iteration 49100 [1083.42 sec]: learning rate : 0.000025 loss : 0.297200 +[12:20:19.057] iteration 49200 [1174.32 sec]: learning rate : 0.000025 loss : 0.460308 +[12:21:50.010] iteration 49300 [1265.27 sec]: learning rate : 0.000025 loss : 0.494049 +[12:23:20.969] iteration 49400 [1356.23 sec]: learning rate : 0.000025 loss : 0.584512 +[12:24:51.878] iteration 49500 [1447.14 sec]: learning rate : 0.000025 loss : 0.698458 +[12:26:22.845] iteration 49600 [1538.11 sec]: learning rate : 0.000025 loss : 0.310852 +[12:27:53.827] iteration 49700 [1629.09 sec]: learning rate : 0.000025 loss : 0.699110 +[12:29:24.742] iteration 49800 [1720.01 sec]: learning rate : 0.000025 loss : 0.711478 +[12:30:55.694] iteration 49900 [1810.96 sec]: learning rate : 0.000025 loss : 0.699729 +[12:32:19.304] Epoch 23 Evaluation: +[12:33:10.134] average MSE: 0.03869752958416939 average PSNR: 29.814509018779873 average SSIM: 0.7349999334529023 +[12:33:17.828] iteration 50000 [7.63 sec]: learning rate : 0.000025 loss : 0.322018 +[12:34:48.747] iteration 50100 [98.54 sec]: learning rate : 0.000025 loss : 0.330535 +[12:36:19.708] iteration 50200 [189.51 sec]: learning rate : 0.000025 loss : 0.601770 +[12:37:50.607] iteration 50300 [280.40 sec]: learning rate : 0.000025 loss : 0.482253 +[12:39:21.533] iteration 50400 [371.33 sec]: learning rate : 0.000025 loss : 0.672355 +[12:40:52.473] iteration 50500 [462.27 sec]: learning rate : 0.000025 loss : 0.555429 +[12:42:23.378] iteration 50600 [553.18 sec]: learning rate : 0.000025 loss : 0.354191 +[12:43:54.364] iteration 50700 [644.16 sec]: learning rate : 0.000025 loss : 0.492750 +[12:45:25.335] iteration 50800 [735.13 sec]: learning rate : 0.000025 loss : 0.423676 +[12:46:56.240] iteration 50900 [826.04 sec]: learning rate : 0.000025 loss : 0.450236 +[12:48:27.187] iteration 51000 [916.98 sec]: learning rate : 0.000025 loss : 0.444618 +[12:49:58.083] iteration 51100 [1007.88 sec]: learning rate : 0.000025 loss : 0.454842 +[12:51:29.042] iteration 51200 [1098.84 sec]: learning rate : 0.000025 loss : 0.352529 +[12:52:59.992] iteration 51300 [1189.79 sec]: learning rate : 0.000025 loss : 0.606555 +[12:54:30.906] iteration 51400 [1280.70 sec]: learning rate : 0.000025 loss : 0.672483 +[12:56:01.863] iteration 51500 [1371.66 sec]: learning rate : 0.000025 loss : 0.372294 +[12:57:32.824] iteration 51600 [1462.62 sec]: learning rate : 0.000025 loss : 0.568868 +[12:59:03.739] iteration 51700 [1553.54 sec]: learning rate : 0.000025 loss : 0.380025 +[13:00:34.713] iteration 51800 [1644.51 sec]: learning rate : 0.000025 loss : 0.710488 +[13:02:05.719] iteration 51900 [1735.52 sec]: learning rate : 0.000025 loss : 0.626722 +[13:03:36.617] iteration 52000 [1826.41 sec]: learning rate : 0.000025 loss : 0.492970 +[13:04:44.823] Epoch 24 Evaluation: +[13:05:34.975] average MSE: 0.03919699788093567 average PSNR: 29.761772746851054 average 
SSIM: 0.7340487836808817 +[13:05:57.994] iteration 52100 [22.95 sec]: learning rate : 0.000025 loss : 0.411576 +[13:07:28.909] iteration 52200 [113.87 sec]: learning rate : 0.000025 loss : 0.440630 +[13:08:59.911] iteration 52300 [204.87 sec]: learning rate : 0.000025 loss : 0.347562 +[13:10:30.885] iteration 52400 [295.84 sec]: learning rate : 0.000025 loss : 0.182353 +[13:12:01.824] iteration 52500 [386.78 sec]: learning rate : 0.000025 loss : 0.420486 +[13:13:32.778] iteration 52600 [477.73 sec]: learning rate : 0.000025 loss : 0.443911 +[13:15:03.746] iteration 52700 [568.70 sec]: learning rate : 0.000025 loss : 0.748213 +[13:16:34.640] iteration 52800 [659.60 sec]: learning rate : 0.000025 loss : 0.408913 +[13:18:05.563] iteration 52900 [750.52 sec]: learning rate : 0.000025 loss : 0.402352 +[13:19:36.478] iteration 53000 [841.43 sec]: learning rate : 0.000025 loss : 0.415690 +[13:21:07.474] iteration 53100 [932.43 sec]: learning rate : 0.000025 loss : 0.734076 +[13:22:38.461] iteration 53200 [1023.42 sec]: learning rate : 0.000025 loss : 0.371733 +[13:24:09.378] iteration 53300 [1114.33 sec]: learning rate : 0.000025 loss : 0.723991 +[13:25:40.369] iteration 53400 [1205.33 sec]: learning rate : 0.000025 loss : 0.336149 +[13:27:11.352] iteration 53500 [1296.31 sec]: learning rate : 0.000025 loss : 0.645842 +[13:28:42.251] iteration 53600 [1387.21 sec]: learning rate : 0.000025 loss : 0.354741 +[13:30:13.228] iteration 53700 [1478.19 sec]: learning rate : 0.000025 loss : 0.208902 +[13:31:44.183] iteration 53800 [1569.14 sec]: learning rate : 0.000025 loss : 0.612292 +[13:33:15.102] iteration 53900 [1660.06 sec]: learning rate : 0.000025 loss : 0.468646 +[13:34:46.079] iteration 54000 [1751.04 sec]: learning rate : 0.000025 loss : 0.477083 +[13:36:17.055] iteration 54100 [1842.01 sec]: learning rate : 0.000025 loss : 0.250727 +[13:37:09.763] Epoch 25 Evaluation: +[13:38:00.406] average MSE: 0.03858011215925217 average PSNR: 29.822689387439265 average SSIM: 0.7349740888767607 +[13:38:38.904] iteration 54200 [38.43 sec]: learning rate : 0.000025 loss : 0.420236 +[13:40:09.922] iteration 54300 [129.45 sec]: learning rate : 0.000025 loss : 0.286559 +[13:41:40.843] iteration 54400 [220.37 sec]: learning rate : 0.000025 loss : 0.281678 +[13:43:11.776] iteration 54500 [311.30 sec]: learning rate : 0.000025 loss : 0.554089 +[13:44:42.687] iteration 54600 [402.21 sec]: learning rate : 0.000025 loss : 0.456426 +[13:46:13.587] iteration 54700 [493.11 sec]: learning rate : 0.000025 loss : 0.558625 +[13:47:44.549] iteration 54800 [584.07 sec]: learning rate : 0.000025 loss : 0.663672 +[13:49:15.457] iteration 54900 [674.98 sec]: learning rate : 0.000025 loss : 0.571642 +[13:50:46.352] iteration 55000 [765.88 sec]: learning rate : 0.000025 loss : 0.491728 +[13:52:17.301] iteration 55100 [856.83 sec]: learning rate : 0.000025 loss : 0.418976 +[13:53:48.191] iteration 55200 [947.72 sec]: learning rate : 0.000025 loss : 0.367821 +[13:55:19.141] iteration 55300 [1038.67 sec]: learning rate : 0.000025 loss : 0.545848 +[13:56:50.095] iteration 55400 [1129.62 sec]: learning rate : 0.000025 loss : 0.507836 +[13:58:21.008] iteration 55500 [1220.53 sec]: learning rate : 0.000025 loss : 0.372452 +[13:59:51.986] iteration 55600 [1311.51 sec]: learning rate : 0.000025 loss : 0.308942 +[14:01:22.904] iteration 55700 [1402.43 sec]: learning rate : 0.000025 loss : 0.497906 +[14:02:53.809] iteration 55800 [1493.33 sec]: learning rate : 0.000025 loss : 0.375975 +[14:04:24.787] iteration 55900 [1584.31 sec]: learning 
rate : 0.000025 loss : 0.513326 +[14:05:55.768] iteration 56000 [1675.30 sec]: learning rate : 0.000025 loss : 0.679341 +[14:07:26.676] iteration 56100 [1766.20 sec]: learning rate : 0.000025 loss : 0.647482 +[14:08:57.616] iteration 56200 [1857.14 sec]: learning rate : 0.000025 loss : 0.548312 +[14:09:34.856] Epoch 26 Evaluation: +[14:10:26.247] average MSE: 0.038803089410066605 average PSNR: 29.809697124320262 average SSIM: 0.7353536848433183 +[14:11:20.265] iteration 56300 [53.95 sec]: learning rate : 0.000025 loss : 0.399411 +[14:12:51.166] iteration 56400 [144.85 sec]: learning rate : 0.000025 loss : 0.467570 +[14:14:22.112] iteration 56500 [235.80 sec]: learning rate : 0.000025 loss : 0.581007 +[14:15:52.997] iteration 56600 [326.68 sec]: learning rate : 0.000025 loss : 0.404399 +[14:17:23.960] iteration 56700 [417.65 sec]: learning rate : 0.000025 loss : 0.470836 +[14:18:54.908] iteration 56800 [508.60 sec]: learning rate : 0.000025 loss : 0.489852 +[14:20:25.821] iteration 56900 [599.51 sec]: learning rate : 0.000025 loss : 0.628283 +[14:21:56.739] iteration 57000 [690.42 sec]: learning rate : 0.000025 loss : 0.407427 +[14:23:27.692] iteration 57100 [781.38 sec]: learning rate : 0.000025 loss : 0.466201 +[14:24:58.605] iteration 57200 [872.29 sec]: learning rate : 0.000025 loss : 0.656813 +[14:26:29.550] iteration 57300 [963.24 sec]: learning rate : 0.000025 loss : 0.537057 +[14:28:00.441] iteration 57400 [1054.13 sec]: learning rate : 0.000025 loss : 0.583988 +[14:29:31.418] iteration 57500 [1145.10 sec]: learning rate : 0.000025 loss : 0.509700 +[14:31:02.402] iteration 57600 [1236.09 sec]: learning rate : 0.000025 loss : 0.257904 +[14:32:33.294] iteration 57700 [1326.98 sec]: learning rate : 0.000025 loss : 0.399173 +[14:34:04.237] iteration 57800 [1417.92 sec]: learning rate : 0.000025 loss : 0.481033 +[14:35:35.188] iteration 57900 [1508.87 sec]: learning rate : 0.000025 loss : 0.694433 +[14:37:06.078] iteration 58000 [1599.76 sec]: learning rate : 0.000025 loss : 0.404572 +[14:38:37.039] iteration 58100 [1690.72 sec]: learning rate : 0.000025 loss : 0.478599 +[14:40:07.987] iteration 58200 [1781.67 sec]: learning rate : 0.000025 loss : 0.538219 +[14:41:38.882] iteration 58300 [1872.57 sec]: learning rate : 0.000025 loss : 0.296243 +[14:42:00.676] Epoch 27 Evaluation: +[14:42:51.166] average MSE: 0.03849368914961815 average PSNR: 29.844574548082573 average SSIM: 0.735617745660381 +[14:44:00.590] iteration 58400 [69.36 sec]: learning rate : 0.000025 loss : 0.458992 +[14:45:31.494] iteration 58500 [160.26 sec]: learning rate : 0.000025 loss : 0.467974 +[14:47:02.445] iteration 58600 [251.21 sec]: learning rate : 0.000025 loss : 0.527686 +[14:48:33.386] iteration 58700 [342.15 sec]: learning rate : 0.000025 loss : 0.352629 +[14:50:04.275] iteration 58800 [433.04 sec]: learning rate : 0.000025 loss : 0.200043 +[14:51:35.215] iteration 58900 [523.98 sec]: learning rate : 0.000025 loss : 0.542312 +[14:53:06.201] iteration 59000 [614.97 sec]: learning rate : 0.000025 loss : 1.007141 +[14:54:37.113] iteration 59100 [705.88 sec]: learning rate : 0.000025 loss : 0.619384 +[14:56:08.057] iteration 59200 [796.82 sec]: learning rate : 0.000025 loss : 0.462536 +[14:57:39.020] iteration 59300 [887.79 sec]: learning rate : 0.000025 loss : 0.481699 +[14:59:09.929] iteration 59400 [978.70 sec]: learning rate : 0.000025 loss : 0.467209 +[15:00:40.893] iteration 59500 [1069.66 sec]: learning rate : 0.000025 loss : 0.353733 +[15:02:11.791] iteration 59600 [1160.56 sec]: learning rate : 0.000025 
loss : 0.392490 +[15:03:42.734] iteration 59700 [1251.50 sec]: learning rate : 0.000025 loss : 0.397195 +[15:05:13.688] iteration 59800 [1342.45 sec]: learning rate : 0.000025 loss : 0.392662 +[15:06:44.600] iteration 59900 [1433.37 sec]: learning rate : 0.000025 loss : 0.454919 +[15:08:15.524] iteration 60000 [1524.29 sec]: learning rate : 0.000006 loss : 0.311906 +[15:08:15.683] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_60000.pth +[15:09:46.652] iteration 60100 [1615.42 sec]: learning rate : 0.000013 loss : 0.352374 +[15:11:17.572] iteration 60200 [1706.34 sec]: learning rate : 0.000013 loss : 0.519971 +[15:12:48.543] iteration 60300 [1797.31 sec]: learning rate : 0.000013 loss : 0.269691 +[15:14:19.522] iteration 60400 [1888.29 sec]: learning rate : 0.000013 loss : 0.484294 +[15:14:25.853] Epoch 28 Evaluation: +[15:15:16.281] average MSE: 0.038553956896066666 average PSNR: 29.834543387119908 average SSIM: 0.735095751992168 +[15:16:41.115] iteration 60500 [84.77 sec]: learning rate : 0.000013 loss : 0.378085 +[15:18:12.132] iteration 60600 [175.78 sec]: learning rate : 0.000013 loss : 0.533959 +[15:19:43.036] iteration 60700 [266.69 sec]: learning rate : 0.000013 loss : 0.405302 +[15:21:13.996] iteration 60800 [357.65 sec]: learning rate : 0.000013 loss : 0.500040 +[15:22:44.965] iteration 60900 [448.62 sec]: learning rate : 0.000013 loss : 0.558223 +[15:24:15.900] iteration 61000 [539.55 sec]: learning rate : 0.000013 loss : 0.449531 +[15:25:46.896] iteration 61100 [630.55 sec]: learning rate : 0.000013 loss : 0.407114 +[15:27:17.828] iteration 61200 [721.48 sec]: learning rate : 0.000013 loss : 0.450242 +[15:28:48.732] iteration 61300 [812.38 sec]: learning rate : 0.000013 loss : 0.569439 +[15:30:19.702] iteration 61400 [903.35 sec]: learning rate : 0.000013 loss : 0.664510 +[15:31:50.618] iteration 61500 [994.27 sec]: learning rate : 0.000013 loss : 0.610242 +[15:33:21.555] iteration 61600 [1085.21 sec]: learning rate : 0.000013 loss : 0.675329 +[15:34:52.553] iteration 61700 [1176.20 sec]: learning rate : 0.000013 loss : 0.430708 +[15:36:23.500] iteration 61800 [1267.15 sec]: learning rate : 0.000013 loss : 0.388616 +[15:37:54.458] iteration 61900 [1358.11 sec]: learning rate : 0.000013 loss : 0.331367 +[15:39:25.423] iteration 62000 [1449.07 sec]: learning rate : 0.000013 loss : 0.474787 +[15:40:56.342] iteration 62100 [1539.99 sec]: learning rate : 0.000013 loss : 0.358514 +[15:42:27.303] iteration 62200 [1630.95 sec]: learning rate : 0.000013 loss : 0.473254 +[15:43:58.210] iteration 62300 [1721.86 sec]: learning rate : 0.000013 loss : 0.287897 +[15:45:29.150] iteration 62400 [1812.80 sec]: learning rate : 0.000013 loss : 0.273100 +[15:46:50.990] Epoch 29 Evaluation: +[15:47:41.710] average MSE: 0.03837694600224495 average PSNR: 29.86154978981728 average SSIM: 0.7356089007175978 +[15:47:51.086] iteration 62500 [9.31 sec]: learning rate : 0.000013 loss : 0.326566 +[15:49:22.017] iteration 62600 [100.24 sec]: learning rate : 0.000013 loss : 0.562395 +[15:50:53.026] iteration 62700 [191.25 sec]: learning rate : 0.000013 loss : 0.571950 +[15:52:23.984] iteration 62800 [282.21 sec]: learning rate : 0.000013 loss : 0.889067 +[15:53:54.875] iteration 62900 [373.10 sec]: learning rate : 0.000013 loss : 0.307693 +[15:55:25.834] iteration 63000 [464.06 sec]: learning rate : 0.000013 loss : 0.465623 +[15:56:56.757] iteration 63100 [554.98 sec]: learning rate : 0.000013 loss : 0.544728 +[15:58:27.662] iteration 63200 [645.88 sec]: learning rate : 0.000013 loss : 0.381824 
+[15:59:58.590] iteration 63300 [736.81 sec]: learning rate : 0.000013 loss : 0.624892 +[16:01:29.566] iteration 63400 [827.79 sec]: learning rate : 0.000013 loss : 0.547227 +[16:03:00.466] iteration 63500 [918.69 sec]: learning rate : 0.000013 loss : 0.235375 +[16:04:31.415] iteration 63600 [1009.64 sec]: learning rate : 0.000013 loss : 0.856354 +[16:06:02.295] iteration 63700 [1100.52 sec]: learning rate : 0.000013 loss : 0.265848 +[16:07:33.242] iteration 63800 [1191.46 sec]: learning rate : 0.000013 loss : 0.565809 +[16:09:04.181] iteration 63900 [1282.40 sec]: learning rate : 0.000013 loss : 0.584474 +[16:10:35.077] iteration 64000 [1373.30 sec]: learning rate : 0.000013 loss : 0.201474 +[16:12:06.026] iteration 64100 [1464.25 sec]: learning rate : 0.000013 loss : 0.609563 +[16:13:36.999] iteration 64200 [1555.22 sec]: learning rate : 0.000013 loss : 0.652952 +[16:15:07.908] iteration 64300 [1646.13 sec]: learning rate : 0.000013 loss : 0.476678 +[16:16:38.860] iteration 64400 [1737.08 sec]: learning rate : 0.000013 loss : 0.424076 +[16:18:09.811] iteration 64500 [1828.03 sec]: learning rate : 0.000013 loss : 0.487011 +[16:19:16.147] Epoch 30 Evaluation: +[16:20:06.782] average MSE: 0.03852885589003563 average PSNR: 29.83714433224949 average SSIM: 0.7356717223911535 +[16:20:31.612] iteration 64600 [24.76 sec]: learning rate : 0.000013 loss : 0.386876 +[16:22:02.620] iteration 64700 [115.77 sec]: learning rate : 0.000013 loss : 0.355423 +[16:23:33.580] iteration 64800 [206.73 sec]: learning rate : 0.000013 loss : 0.504399 +[16:25:04.484] iteration 64900 [297.63 sec]: learning rate : 0.000013 loss : 0.242264 +[16:26:35.441] iteration 65000 [388.59 sec]: learning rate : 0.000013 loss : 0.581064 +[16:28:06.334] iteration 65100 [479.48 sec]: learning rate : 0.000013 loss : 0.702034 +[16:29:37.318] iteration 65200 [570.47 sec]: learning rate : 0.000013 loss : 0.478523 +[16:31:08.291] iteration 65300 [661.44 sec]: learning rate : 0.000013 loss : 0.353277 +[16:32:39.198] iteration 65400 [752.35 sec]: learning rate : 0.000013 loss : 0.559404 +[16:34:10.164] iteration 65500 [843.32 sec]: learning rate : 0.000013 loss : 0.287731 +[16:35:41.155] iteration 65600 [934.31 sec]: learning rate : 0.000013 loss : 0.528510 +[16:37:12.079] iteration 65700 [1025.24 sec]: learning rate : 0.000013 loss : 0.585292 +[16:38:43.035] iteration 65800 [1116.18 sec]: learning rate : 0.000013 loss : 0.569790 +[16:40:13.993] iteration 65900 [1207.14 sec]: learning rate : 0.000013 loss : 0.349384 +[16:41:44.898] iteration 66000 [1298.05 sec]: learning rate : 0.000013 loss : 0.593354 +[16:43:15.856] iteration 66100 [1389.01 sec]: learning rate : 0.000013 loss : 0.430358 +[16:44:46.779] iteration 66200 [1479.93 sec]: learning rate : 0.000013 loss : 0.262511 +[16:46:17.836] iteration 66300 [1570.99 sec]: learning rate : 0.000013 loss : 0.537725 +[16:47:48.769] iteration 66400 [1661.92 sec]: learning rate : 0.000013 loss : 0.782557 +[16:49:19.669] iteration 66500 [1752.82 sec]: learning rate : 0.000013 loss : 0.595512 +[16:50:50.631] iteration 66600 [1843.78 sec]: learning rate : 0.000013 loss : 0.376501 +[16:51:41.519] Epoch 31 Evaluation: +[16:52:31.933] average MSE: 0.03854268789291382 average PSNR: 29.842509923542327 average SSIM: 0.7357836640326756 +[16:53:12.333] iteration 66700 [40.33 sec]: learning rate : 0.000013 loss : 0.286695 +[16:54:43.227] iteration 66800 [131.23 sec]: learning rate : 0.000013 loss : 0.238853 +[16:56:14.174] iteration 66900 [222.17 sec]: learning rate : 0.000013 loss : 0.424680 
+[16:57:45.124] iteration 67000 [313.12 sec]: learning rate : 0.000013 loss : 0.385469 +[16:59:16.022] iteration 67100 [404.02 sec]: learning rate : 0.000013 loss : 0.588890 +[17:00:46.965] iteration 67200 [494.97 sec]: learning rate : 0.000013 loss : 0.618929 +[17:02:17.873] iteration 67300 [585.87 sec]: learning rate : 0.000013 loss : 0.324511 +[17:03:48.827] iteration 67400 [676.83 sec]: learning rate : 0.000013 loss : 0.708345 +[17:05:19.779] iteration 67500 [767.78 sec]: learning rate : 0.000013 loss : 0.392757 +[17:06:50.677] iteration 67600 [858.68 sec]: learning rate : 0.000013 loss : 0.536189 +[17:08:21.662] iteration 67700 [949.66 sec]: learning rate : 0.000013 loss : 0.420836 +[17:09:52.621] iteration 67800 [1040.62 sec]: learning rate : 0.000013 loss : 0.388192 +[17:11:23.516] iteration 67900 [1131.51 sec]: learning rate : 0.000013 loss : 0.674182 +[17:12:54.483] iteration 68000 [1222.48 sec]: learning rate : 0.000013 loss : 0.380920 +[17:14:25.386] iteration 68100 [1313.38 sec]: learning rate : 0.000013 loss : 0.214347 +[17:15:56.334] iteration 68200 [1404.33 sec]: learning rate : 0.000013 loss : 0.267809 +[17:17:27.292] iteration 68300 [1495.29 sec]: learning rate : 0.000013 loss : 0.296249 +[17:18:58.187] iteration 68400 [1586.19 sec]: learning rate : 0.000013 loss : 0.506998 +[17:20:29.150] iteration 68500 [1677.15 sec]: learning rate : 0.000013 loss : 0.546443 +[17:22:00.142] iteration 68600 [1768.14 sec]: learning rate : 0.000013 loss : 0.866572 +[17:23:31.060] iteration 68700 [1859.06 sec]: learning rate : 0.000013 loss : 0.523857 +[17:24:06.491] Epoch 32 Evaluation: +[17:24:57.394] average MSE: 0.03848186135292053 average PSNR: 29.849612507057383 average SSIM: 0.7362353456526353 +[17:25:53.256] iteration 68800 [55.79 sec]: learning rate : 0.000013 loss : 0.366609 +[17:27:24.204] iteration 68900 [146.74 sec]: learning rate : 0.000013 loss : 0.330355 +[17:28:55.129] iteration 69000 [237.67 sec]: learning rate : 0.000013 loss : 0.363407 +[17:30:26.096] iteration 69100 [328.64 sec]: learning rate : 0.000013 loss : 0.613000 +[17:31:57.007] iteration 69200 [419.54 sec]: learning rate : 0.000013 loss : 0.684758 +[17:33:27.992] iteration 69300 [510.53 sec]: learning rate : 0.000013 loss : 0.560658 +[17:34:58.981] iteration 69400 [601.52 sec]: learning rate : 0.000013 loss : 0.538495 +[17:36:29.905] iteration 69500 [692.44 sec]: learning rate : 0.000013 loss : 0.540323 +[17:38:00.875] iteration 69600 [783.41 sec]: learning rate : 0.000013 loss : 0.596963 +[17:39:31.856] iteration 69700 [874.39 sec]: learning rate : 0.000013 loss : 0.591516 +[17:41:02.757] iteration 69800 [965.30 sec]: learning rate : 0.000013 loss : 0.457718 +[17:42:33.692] iteration 69900 [1056.23 sec]: learning rate : 0.000013 loss : 0.408494 +[17:44:04.665] iteration 70000 [1147.20 sec]: learning rate : 0.000013 loss : 0.539344 +[17:45:35.581] iteration 70100 [1238.12 sec]: learning rate : 0.000013 loss : 0.530718 +[17:47:06.549] iteration 70200 [1329.09 sec]: learning rate : 0.000013 loss : 0.467009 +[17:48:37.527] iteration 70300 [1420.07 sec]: learning rate : 0.000013 loss : 0.449466 +[17:50:08.433] iteration 70400 [1510.97 sec]: learning rate : 0.000013 loss : 0.588262 +[17:51:39.342] iteration 70500 [1601.88 sec]: learning rate : 0.000013 loss : 0.366243 +[17:53:10.244] iteration 70600 [1692.78 sec]: learning rate : 0.000013 loss : 0.297096 +[17:54:41.198] iteration 70700 [1783.74 sec]: learning rate : 0.000013 loss : 0.757296 +[17:56:12.163] iteration 70800 [1874.70 sec]: learning rate : 0.000013 loss : 
0.491323 +[17:56:32.145] Epoch 33 Evaluation: +[17:57:22.647] average MSE: 0.038502659648656845 average PSNR: 29.84599803849133 average SSIM: 0.7366494081963287 +[17:58:33.836] iteration 70900 [71.12 sec]: learning rate : 0.000013 loss : 0.430810 +[18:00:04.802] iteration 71000 [162.09 sec]: learning rate : 0.000013 loss : 0.509532 +[18:01:35.789] iteration 71100 [253.08 sec]: learning rate : 0.000013 loss : 0.539509 +[18:03:06.703] iteration 71200 [343.99 sec]: learning rate : 0.000013 loss : 0.329357 +[18:04:37.672] iteration 71300 [434.96 sec]: learning rate : 0.000013 loss : 0.659350 +[18:06:08.576] iteration 71400 [525.86 sec]: learning rate : 0.000013 loss : 0.348541 +[18:07:39.548] iteration 71500 [616.84 sec]: learning rate : 0.000013 loss : 0.720753 +[18:09:10.489] iteration 71600 [707.78 sec]: learning rate : 0.000013 loss : 0.549567 +[18:10:41.420] iteration 71700 [798.71 sec]: learning rate : 0.000013 loss : 0.509491 +[18:12:12.415] iteration 71800 [889.70 sec]: learning rate : 0.000013 loss : 0.489419 +[18:13:43.351] iteration 71900 [980.64 sec]: learning rate : 0.000013 loss : 0.699134 +[18:15:14.257] iteration 72000 [1071.54 sec]: learning rate : 0.000013 loss : 0.453773 +[18:16:45.210] iteration 72100 [1162.50 sec]: learning rate : 0.000013 loss : 0.794755 +[18:18:16.126] iteration 72200 [1253.41 sec]: learning rate : 0.000013 loss : 0.302222 +[18:19:47.072] iteration 72300 [1344.36 sec]: learning rate : 0.000013 loss : 0.421455 +[18:21:18.072] iteration 72400 [1435.36 sec]: learning rate : 0.000013 loss : 0.720749 +[18:22:48.987] iteration 72500 [1526.27 sec]: learning rate : 0.000013 loss : 0.354838 +[18:24:19.923] iteration 72600 [1617.21 sec]: learning rate : 0.000013 loss : 0.344319 +[18:25:50.899] iteration 72700 [1708.18 sec]: learning rate : 0.000013 loss : 0.579001 +[18:27:21.801] iteration 72800 [1799.09 sec]: learning rate : 0.000013 loss : 0.306849 +[18:28:52.774] iteration 72900 [1890.06 sec]: learning rate : 0.000013 loss : 0.364474 +[18:28:57.290] Epoch 34 Evaluation: +[18:29:47.658] average MSE: 0.038440655916929245 average PSNR: 29.852505734029375 average SSIM: 0.7359570521925783 +[18:31:14.302] iteration 73000 [86.58 sec]: learning rate : 0.000013 loss : 0.384775 +[18:32:45.349] iteration 73100 [177.62 sec]: learning rate : 0.000013 loss : 1.011246 +[18:34:16.349] iteration 73200 [268.62 sec]: learning rate : 0.000013 loss : 0.429093 +[18:35:47.288] iteration 73300 [359.56 sec]: learning rate : 0.000013 loss : 0.601889 +[18:37:18.258] iteration 73400 [450.53 sec]: learning rate : 0.000013 loss : 0.464064 +[18:38:49.226] iteration 73500 [541.50 sec]: learning rate : 0.000013 loss : 0.282627 +[18:40:20.136] iteration 73600 [632.41 sec]: learning rate : 0.000013 loss : 0.399679 +[18:41:51.101] iteration 73700 [723.37 sec]: learning rate : 0.000013 loss : 0.509381 +[18:43:22.067] iteration 73800 [814.34 sec]: learning rate : 0.000013 loss : 0.734387 +[18:44:52.976] iteration 73900 [905.25 sec]: learning rate : 0.000013 loss : 0.673438 +[18:46:23.925] iteration 74000 [996.20 sec]: learning rate : 0.000013 loss : 0.691100 +[18:47:54.824] iteration 74100 [1087.10 sec]: learning rate : 0.000013 loss : 0.401930 +[18:49:25.780] iteration 74200 [1178.05 sec]: learning rate : 0.000013 loss : 0.620544 +[18:50:56.705] iteration 74300 [1268.98 sec]: learning rate : 0.000013 loss : 0.462793 +[18:52:27.604] iteration 74400 [1359.88 sec]: learning rate : 0.000013 loss : 0.362225 +[18:53:58.574] iteration 74500 [1450.85 sec]: learning rate : 0.000013 loss : 0.469396 
+[18:55:29.527] iteration 74600 [1541.80 sec]: learning rate : 0.000013 loss : 0.340389 +[18:57:00.421] iteration 74700 [1632.70 sec]: learning rate : 0.000013 loss : 0.599393 +[18:58:31.378] iteration 74800 [1723.65 sec]: learning rate : 0.000013 loss : 0.555480 +[19:00:02.292] iteration 74900 [1814.57 sec]: learning rate : 0.000013 loss : 0.307498 +[19:01:22.337] Epoch 35 Evaluation: +[19:02:13.109] average MSE: 0.03834831714630127 average PSNR: 29.866220540931394 average SSIM: 0.7363410520796619 +[19:02:24.322] iteration 75000 [11.15 sec]: learning rate : 0.000013 loss : 0.162321 +[19:03:55.348] iteration 75100 [102.17 sec]: learning rate : 0.000013 loss : 0.482795 +[19:05:26.261] iteration 75200 [193.09 sec]: learning rate : 0.000013 loss : 0.381796 +[19:06:57.245] iteration 75300 [284.07 sec]: learning rate : 0.000013 loss : 0.597930 +[19:08:28.227] iteration 75400 [375.05 sec]: learning rate : 0.000013 loss : 0.521937 +[19:09:59.146] iteration 75500 [465.97 sec]: learning rate : 0.000013 loss : 0.528476 +[19:11:30.126] iteration 75600 [556.95 sec]: learning rate : 0.000013 loss : 0.662810 +[19:13:01.096] iteration 75700 [647.92 sec]: learning rate : 0.000013 loss : 0.782105 +[19:14:31.989] iteration 75800 [738.81 sec]: learning rate : 0.000013 loss : 0.593421 +[19:16:02.956] iteration 75900 [829.78 sec]: learning rate : 0.000013 loss : 0.599540 +[19:17:33.937] iteration 76000 [920.76 sec]: learning rate : 0.000013 loss : 0.206365 +[19:19:04.838] iteration 76100 [1011.66 sec]: learning rate : 0.000013 loss : 0.477476 +[19:20:35.755] iteration 76200 [1102.58 sec]: learning rate : 0.000013 loss : 0.504556 +[19:22:06.667] iteration 76300 [1193.49 sec]: learning rate : 0.000013 loss : 0.478698 +[19:23:37.638] iteration 76400 [1284.46 sec]: learning rate : 0.000013 loss : 0.421651 +[19:25:08.564] iteration 76500 [1375.39 sec]: learning rate : 0.000013 loss : 0.359423 +[19:26:39.469] iteration 76600 [1466.29 sec]: learning rate : 0.000013 loss : 0.535767 +[19:28:10.449] iteration 76700 [1557.28 sec]: learning rate : 0.000013 loss : 0.730385 +[19:29:41.446] iteration 76800 [1648.27 sec]: learning rate : 0.000013 loss : 0.272582 +[19:31:12.359] iteration 76900 [1739.18 sec]: learning rate : 0.000013 loss : 0.630837 +[19:32:43.277] iteration 77000 [1830.10 sec]: learning rate : 0.000013 loss : 0.659374 +[19:33:47.797] Epoch 36 Evaluation: +[19:34:38.349] average MSE: 0.03835471719503403 average PSNR: 29.871964688166624 average SSIM: 0.7366090735238014 +[19:35:05.165] iteration 77100 [26.75 sec]: learning rate : 0.000013 loss : 0.400843 +[19:36:36.069] iteration 77200 [117.65 sec]: learning rate : 0.000013 loss : 0.333202 +[19:38:07.060] iteration 77300 [208.64 sec]: learning rate : 0.000013 loss : 0.464431 +[19:39:37.971] iteration 77400 [299.55 sec]: learning rate : 0.000013 loss : 0.478782 +[19:41:08.947] iteration 77500 [390.53 sec]: learning rate : 0.000013 loss : 0.493268 +[19:42:39.933] iteration 77600 [481.52 sec]: learning rate : 0.000013 loss : 0.794724 +[19:44:10.856] iteration 77700 [572.44 sec]: learning rate : 0.000013 loss : 0.615866 +[19:45:41.863] iteration 77800 [663.45 sec]: learning rate : 0.000013 loss : 0.454923 +[19:47:12.833] iteration 77900 [754.42 sec]: learning rate : 0.000013 loss : 0.449040 +[19:48:43.740] iteration 78000 [845.32 sec]: learning rate : 0.000013 loss : 0.441693 +[19:50:14.666] iteration 78100 [936.25 sec]: learning rate : 0.000013 loss : 0.466369 +[19:51:45.577] iteration 78200 [1027.16 sec]: learning rate : 0.000013 loss : 0.589226 +[19:53:16.560] 
iteration 78300 [1118.14 sec]: learning rate : 0.000013 loss : 0.416112 +[19:54:47.522] iteration 78400 [1209.11 sec]: learning rate : 0.000013 loss : 0.375351 +[19:56:18.419] iteration 78500 [1300.00 sec]: learning rate : 0.000013 loss : 0.513987 +[19:57:49.414] iteration 78600 [1391.00 sec]: learning rate : 0.000013 loss : 0.561443 +[19:59:20.363] iteration 78700 [1481.95 sec]: learning rate : 0.000013 loss : 0.384652 +[20:00:51.277] iteration 78800 [1572.86 sec]: learning rate : 0.000013 loss : 0.437442 +[20:02:22.231] iteration 78900 [1663.81 sec]: learning rate : 0.000013 loss : 0.717858 +[20:03:53.132] iteration 79000 [1754.71 sec]: learning rate : 0.000013 loss : 0.567289 +[20:05:24.085] iteration 79100 [1845.67 sec]: learning rate : 0.000013 loss : 0.712598 +[20:06:13.143] Epoch 37 Evaluation: +[20:07:03.623] average MSE: 0.03836868330836296 average PSNR: 29.86458634367572 average SSIM: 0.7365880512329909 +[20:07:45.842] iteration 79200 [42.15 sec]: learning rate : 0.000013 loss : 0.686980 +[20:09:16.747] iteration 79300 [133.06 sec]: learning rate : 0.000013 loss : 0.249809 +[20:10:47.747] iteration 79400 [224.06 sec]: learning rate : 0.000013 loss : 0.301013 +[20:12:18.690] iteration 79500 [315.00 sec]: learning rate : 0.000013 loss : 0.359399 +[20:13:49.594] iteration 79600 [405.90 sec]: learning rate : 0.000013 loss : 0.589247 +[20:15:20.589] iteration 79700 [496.90 sec]: learning rate : 0.000013 loss : 0.689512 +[20:16:51.619] iteration 79800 [587.93 sec]: learning rate : 0.000013 loss : 0.597914 +[20:18:22.525] iteration 79900 [678.83 sec]: learning rate : 0.000013 loss : 0.368884 +[20:19:53.479] iteration 80000 [769.79 sec]: learning rate : 0.000003 loss : 0.613454 +[20:19:53.635] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_80000.pth +[20:21:24.551] iteration 80100 [860.86 sec]: learning rate : 0.000006 loss : 0.409689 +[20:22:55.495] iteration 80200 [951.80 sec]: learning rate : 0.000006 loss : 0.357319 +[20:24:26.468] iteration 80300 [1042.78 sec]: learning rate : 0.000006 loss : 0.818689 +[20:25:57.381] iteration 80400 [1133.69 sec]: learning rate : 0.000006 loss : 0.455555 +[20:27:28.361] iteration 80500 [1224.67 sec]: learning rate : 0.000006 loss : 0.732814 +[20:28:59.325] iteration 80600 [1315.63 sec]: learning rate : 0.000006 loss : 0.542416 +[20:30:30.234] iteration 80700 [1406.54 sec]: learning rate : 0.000006 loss : 0.527890 +[20:32:01.187] iteration 80800 [1497.50 sec]: learning rate : 0.000006 loss : 0.330062 +[20:33:32.090] iteration 80900 [1588.40 sec]: learning rate : 0.000006 loss : 0.545569 +[20:35:03.043] iteration 81000 [1679.35 sec]: learning rate : 0.000006 loss : 0.455424 +[20:36:34.009] iteration 81100 [1770.32 sec]: learning rate : 0.000006 loss : 0.585406 +[20:38:04.926] iteration 81200 [1861.23 sec]: learning rate : 0.000006 loss : 0.460711 +[20:38:38.543] Epoch 38 Evaluation: +[20:39:28.894] average MSE: 0.03838975355029106 average PSNR: 29.860925256647814 average SSIM: 0.7367035320484399 +[20:40:26.456] iteration 81300 [57.50 sec]: learning rate : 0.000006 loss : 0.388001 +[20:41:57.493] iteration 81400 [148.53 sec]: learning rate : 0.000006 loss : 0.461190 +[20:43:28.421] iteration 81500 [239.46 sec]: learning rate : 0.000006 loss : 0.443614 +[20:44:59.353] iteration 81600 [330.39 sec]: learning rate : 0.000006 loss : 0.456548 +[20:46:30.317] iteration 81700 [421.35 sec]: learning rate : 0.000006 loss : 0.383725 +[20:48:01.232] iteration 81800 [512.27 sec]: learning rate : 0.000006 loss : 0.517592 +[20:49:32.195] iteration 81900 
[603.23 sec]: learning rate : 0.000006 loss : 0.658453 +[20:51:03.108] iteration 82000 [694.15 sec]: learning rate : 0.000006 loss : 0.628136 +[20:52:34.082] iteration 82100 [785.12 sec]: learning rate : 0.000006 loss : 0.689390 +[20:54:04.996] iteration 82200 [876.04 sec]: learning rate : 0.000006 loss : 0.318416 +[20:55:35.898] iteration 82300 [966.94 sec]: learning rate : 0.000006 loss : 0.740428 +[20:57:06.850] iteration 82400 [1057.89 sec]: learning rate : 0.000006 loss : 0.557476 +[20:58:37.797] iteration 82500 [1148.84 sec]: learning rate : 0.000006 loss : 0.664973 +[21:00:08.706] iteration 82600 [1239.74 sec]: learning rate : 0.000006 loss : 0.763912 +[21:01:39.677] iteration 82700 [1330.71 sec]: learning rate : 0.000006 loss : 0.418551 +[21:03:10.660] iteration 82800 [1421.70 sec]: learning rate : 0.000006 loss : 0.507538 +[21:04:41.568] iteration 82900 [1512.61 sec]: learning rate : 0.000006 loss : 0.569803 +[21:06:12.544] iteration 83000 [1603.58 sec]: learning rate : 0.000006 loss : 0.363999 +[21:07:43.451] iteration 83100 [1694.49 sec]: learning rate : 0.000006 loss : 0.363872 +[21:09:14.428] iteration 83200 [1785.47 sec]: learning rate : 0.000006 loss : 0.522915 +[21:10:45.393] iteration 83300 [1876.43 sec]: learning rate : 0.000006 loss : 0.559017 +[21:11:03.552] Epoch 39 Evaluation: +[21:11:53.856] average MSE: 0.03837361931800842 average PSNR: 29.865615061347473 average SSIM: 0.7369288986631557 +[21:13:06.890] iteration 83400 [72.97 sec]: learning rate : 0.000006 loss : 0.640183 +[21:14:37.901] iteration 83500 [163.98 sec]: learning rate : 0.000006 loss : 0.354179 +[21:16:08.866] iteration 83600 [254.94 sec]: learning rate : 0.000006 loss : 0.589438 +[21:17:39.776] iteration 83700 [345.85 sec]: learning rate : 0.000006 loss : 0.484867 +[21:19:10.742] iteration 83800 [436.82 sec]: learning rate : 0.000006 loss : 0.454130 +[21:20:41.707] iteration 83900 [527.78 sec]: learning rate : 0.000006 loss : 0.435028 +[21:22:12.648] iteration 84000 [618.73 sec]: learning rate : 0.000006 loss : 0.998619 +[21:23:43.640] iteration 84100 [709.72 sec]: learning rate : 0.000006 loss : 0.788246 +[21:25:14.545] iteration 84200 [800.62 sec]: learning rate : 0.000006 loss : 0.503057 +[21:26:45.494] iteration 84300 [891.57 sec]: learning rate : 0.000006 loss : 0.477442 +[21:28:16.445] iteration 84400 [982.53 sec]: learning rate : 0.000006 loss : 0.555778 +[21:29:47.372] iteration 84500 [1073.45 sec]: learning rate : 0.000006 loss : 0.443515 +[21:31:18.353] iteration 84600 [1164.43 sec]: learning rate : 0.000006 loss : 0.663052 +[21:32:49.313] iteration 84700 [1255.39 sec]: learning rate : 0.000006 loss : 0.327761 +[21:34:20.250] iteration 84800 [1346.33 sec]: learning rate : 0.000006 loss : 0.532682 +[21:35:51.233] iteration 84900 [1437.31 sec]: learning rate : 0.000006 loss : 0.442781 +[21:37:22.213] iteration 85000 [1528.29 sec]: learning rate : 0.000006 loss : 0.466585 +[21:38:53.123] iteration 85100 [1619.20 sec]: learning rate : 0.000006 loss : 0.440053 +[21:40:24.135] iteration 85200 [1710.21 sec]: learning rate : 0.000006 loss : 0.642921 +[21:41:55.049] iteration 85300 [1801.13 sec]: learning rate : 0.000006 loss : 0.831399 +[21:43:25.992] iteration 85400 [1892.07 sec]: learning rate : 0.000006 loss : 0.323704 +[21:43:28.692] Epoch 40 Evaluation: +[21:44:19.697] average MSE: 0.0382712222635746 average PSNR: 29.8785111613328 average SSIM: 0.7369000798418859 +[21:45:48.262] iteration 85500 [88.50 sec]: learning rate : 0.000006 loss : 0.666675 +[21:47:19.163] iteration 85600 [179.40 sec]: 
learning rate : 0.000006 loss : 0.594603 +[21:48:50.105] iteration 85700 [270.34 sec]: learning rate : 0.000006 loss : 0.344925 +[21:50:21.064] iteration 85800 [361.30 sec]: learning rate : 0.000006 loss : 0.895176 +[21:51:51.959] iteration 85900 [452.19 sec]: learning rate : 0.000006 loss : 0.395032 +[21:53:22.923] iteration 86000 [543.16 sec]: learning rate : 0.000006 loss : 0.411622 +[21:54:53.897] iteration 86100 [634.13 sec]: learning rate : 0.000006 loss : 0.509571 +[21:56:24.804] iteration 86200 [725.04 sec]: learning rate : 0.000006 loss : 0.665362 +[21:57:55.780] iteration 86300 [816.02 sec]: learning rate : 0.000006 loss : 0.374717 +[21:59:26.747] iteration 86400 [906.98 sec]: learning rate : 0.000006 loss : 0.394423 +[22:00:57.667] iteration 86500 [997.90 sec]: learning rate : 0.000006 loss : 0.548579 +[22:02:28.645] iteration 86600 [1088.88 sec]: learning rate : 0.000006 loss : 0.416206 +[22:03:59.575] iteration 86700 [1179.81 sec]: learning rate : 0.000006 loss : 0.307228 +[22:05:30.488] iteration 86800 [1270.72 sec]: learning rate : 0.000006 loss : 0.297318 +[22:07:01.458] iteration 86900 [1361.69 sec]: learning rate : 0.000006 loss : 0.601228 +[22:08:32.359] iteration 87000 [1452.59 sec]: learning rate : 0.000006 loss : 0.333359 +[22:10:03.331] iteration 87100 [1543.57 sec]: learning rate : 0.000006 loss : 0.424934 +[22:11:34.321] iteration 87200 [1634.56 sec]: learning rate : 0.000006 loss : 0.446973 +[22:13:05.241] iteration 87300 [1725.48 sec]: learning rate : 0.000006 loss : 0.500978 +[22:14:36.211] iteration 87400 [1816.45 sec]: learning rate : 0.000006 loss : 0.604088 +[22:15:54.365] Epoch 41 Evaluation: +[22:16:44.911] average MSE: 0.038314688950777054 average PSNR: 29.874670321541412 average SSIM: 0.7371417623906594 +[22:16:57.940] iteration 87500 [12.96 sec]: learning rate : 0.000006 loss : 0.549397 +[22:18:28.870] iteration 87600 [103.89 sec]: learning rate : 0.000006 loss : 0.786232 +[22:19:59.843] iteration 87700 [194.87 sec]: learning rate : 0.000006 loss : 0.337051 +[22:21:30.746] iteration 87800 [285.77 sec]: learning rate : 0.000006 loss : 0.456645 +[22:23:01.705] iteration 87900 [376.73 sec]: learning rate : 0.000006 loss : 0.399393 +[22:24:32.677] iteration 88000 [467.70 sec]: learning rate : 0.000006 loss : 0.375256 +[22:26:03.579] iteration 88100 [558.60 sec]: learning rate : 0.000006 loss : 0.444442 +[22:27:34.537] iteration 88200 [649.56 sec]: learning rate : 0.000006 loss : 0.516579 +[22:29:05.499] iteration 88300 [740.52 sec]: learning rate : 0.000006 loss : 0.396885 +[22:30:36.408] iteration 88400 [831.43 sec]: learning rate : 0.000006 loss : 0.604066 +[22:32:07.341] iteration 88500 [922.36 sec]: learning rate : 0.000006 loss : 0.338514 +[22:33:38.324] iteration 88600 [1013.35 sec]: learning rate : 0.000006 loss : 0.528564 +[22:35:09.237] iteration 88700 [1104.26 sec]: learning rate : 0.000006 loss : 0.431522 +[22:36:40.195] iteration 88800 [1195.22 sec]: learning rate : 0.000006 loss : 0.362996 +[22:38:11.105] iteration 88900 [1286.13 sec]: learning rate : 0.000006 loss : 0.545539 +[22:39:42.088] iteration 89000 [1377.11 sec]: learning rate : 0.000006 loss : 0.483888 +[22:41:13.048] iteration 89100 [1468.07 sec]: learning rate : 0.000006 loss : 0.353376 +[22:42:43.942] iteration 89200 [1558.96 sec]: learning rate : 0.000006 loss : 0.343143 +[22:44:14.890] iteration 89300 [1649.91 sec]: learning rate : 0.000006 loss : 0.673450 +[22:45:45.863] iteration 89400 [1740.88 sec]: learning rate : 0.000006 loss : 0.477650 +[22:47:16.773] iteration 89500 
[1831.79 sec]: learning rate : 0.000006 loss : 0.313194 +[22:48:19.528] Epoch 42 Evaluation: +[22:49:11.673] average MSE: 0.03844667971134186 average PSNR: 29.857066538548906 average SSIM: 0.7371895503851883 +[22:49:40.150] iteration 89600 [28.41 sec]: learning rate : 0.000006 loss : 0.435233 +[22:51:11.166] iteration 89700 [119.43 sec]: learning rate : 0.000006 loss : 0.482490 +[22:52:42.067] iteration 89800 [210.33 sec]: learning rate : 0.000006 loss : 0.443359 +[22:54:13.025] iteration 89900 [301.28 sec]: learning rate : 0.000006 loss : 0.426150 +[22:55:43.926] iteration 90000 [392.19 sec]: learning rate : 0.000006 loss : 0.438549 +[22:57:14.971] iteration 90100 [483.23 sec]: learning rate : 0.000006 loss : 0.530999 +[22:58:45.952] iteration 90200 [574.21 sec]: learning rate : 0.000006 loss : 0.432206 +[23:00:16.881] iteration 90300 [665.14 sec]: learning rate : 0.000006 loss : 0.452101 +[23:01:47.858] iteration 90400 [756.12 sec]: learning rate : 0.000006 loss : 0.391254 +[23:03:18.835] iteration 90500 [847.09 sec]: learning rate : 0.000006 loss : 0.478590 +[23:04:49.736] iteration 90600 [938.00 sec]: learning rate : 0.000006 loss : 0.464989 +[23:06:20.694] iteration 90700 [1028.95 sec]: learning rate : 0.000006 loss : 0.759592 +[23:07:51.658] iteration 90800 [1119.92 sec]: learning rate : 0.000006 loss : 0.454873 +[23:09:22.580] iteration 90900 [1210.85 sec]: learning rate : 0.000006 loss : 0.575918 +[23:10:53.578] iteration 91000 [1301.84 sec]: learning rate : 0.000006 loss : 0.411908 +[23:12:24.486] iteration 91100 [1392.76 sec]: learning rate : 0.000006 loss : 0.505361 +[23:13:55.403] iteration 91200 [1483.66 sec]: learning rate : 0.000006 loss : 0.503747 +[23:15:26.371] iteration 91300 [1574.63 sec]: learning rate : 0.000006 loss : 0.397601 +[23:16:57.288] iteration 91400 [1665.55 sec]: learning rate : 0.000006 loss : 0.707378 +[23:18:28.276] iteration 91500 [1756.54 sec]: learning rate : 0.000006 loss : 0.450827 +[23:19:59.192] iteration 91600 [1847.45 sec]: learning rate : 0.000006 loss : 0.783265 +[23:20:46.514] Epoch 43 Evaluation: +[23:21:37.220] average MSE: 0.038330718874931335 average PSNR: 29.87013321473967 average SSIM: 0.7370529495242165 +[23:22:21.175] iteration 91700 [43.89 sec]: learning rate : 0.000006 loss : 0.657843 +[23:23:52.199] iteration 91800 [134.91 sec]: learning rate : 0.000006 loss : 0.298030 +[23:25:23.119] iteration 91900 [225.83 sec]: learning rate : 0.000006 loss : 0.405729 +[23:26:54.089] iteration 92000 [316.80 sec]: learning rate : 0.000006 loss : 0.643779 +[23:28:25.070] iteration 92100 [407.78 sec]: learning rate : 0.000006 loss : 0.634827 +[23:29:55.977] iteration 92200 [498.69 sec]: learning rate : 0.000006 loss : 0.535469 +[23:31:26.965] iteration 92300 [589.68 sec]: learning rate : 0.000006 loss : 0.372873 +[23:32:57.942] iteration 92400 [680.66 sec]: learning rate : 0.000006 loss : 0.373855 +[23:34:28.861] iteration 92500 [771.57 sec]: learning rate : 0.000006 loss : 0.728354 +[23:35:59.836] iteration 92600 [862.55 sec]: learning rate : 0.000006 loss : 0.323327 +[23:37:30.766] iteration 92700 [953.48 sec]: learning rate : 0.000006 loss : 0.419133 +[23:39:01.703] iteration 92800 [1044.42 sec]: learning rate : 0.000006 loss : 0.634437 +[23:40:32.682] iteration 92900 [1135.39 sec]: learning rate : 0.000006 loss : 0.537041 +[23:42:03.591] iteration 93000 [1226.30 sec]: learning rate : 0.000006 loss : 0.671591 +[23:43:34.561] iteration 93100 [1317.27 sec]: learning rate : 0.000006 loss : 0.449413 +[23:45:05.511] iteration 93200 [1408.22 sec]: 
learning rate : 0.000006 loss : 0.420714 +[23:46:36.419] iteration 93300 [1499.13 sec]: learning rate : 0.000006 loss : 0.329618 +[23:48:07.390] iteration 93400 [1590.10 sec]: learning rate : 0.000006 loss : 0.527158 +[23:49:38.355] iteration 93500 [1681.07 sec]: learning rate : 0.000006 loss : 0.345479 +[23:51:09.269] iteration 93600 [1771.98 sec]: learning rate : 0.000006 loss : 0.528243 +[23:52:40.249] iteration 93700 [1862.96 sec]: learning rate : 0.000006 loss : 0.455755 +[23:53:12.054] Epoch 44 Evaluation: +[23:54:02.670] average MSE: 0.0383526086807251 average PSNR: 29.869200501857282 average SSIM: 0.7370974477716143 +[23:55:02.167] iteration 93800 [59.43 sec]: learning rate : 0.000006 loss : 0.548104 +[23:56:33.063] iteration 93900 [150.33 sec]: learning rate : 0.000006 loss : 0.334103 +[23:58:04.021] iteration 94000 [241.28 sec]: learning rate : 0.000006 loss : 0.402768 +[23:59:34.923] iteration 94100 [332.19 sec]: learning rate : 0.000006 loss : 0.224838 +[00:01:05.832] iteration 94200 [423.09 sec]: learning rate : 0.000006 loss : 0.303955 +[00:02:36.800] iteration 94300 [514.06 sec]: learning rate : 0.000006 loss : 0.266140 +[00:04:07.734] iteration 94400 [605.01 sec]: learning rate : 0.000006 loss : 0.365971 +[00:05:38.707] iteration 94500 [695.97 sec]: learning rate : 0.000006 loss : 0.461634 +[00:07:09.687] iteration 94600 [786.95 sec]: learning rate : 0.000006 loss : 0.686620 +[00:08:40.591] iteration 94700 [877.85 sec]: learning rate : 0.000006 loss : 0.638951 +[00:10:11.558] iteration 94800 [968.82 sec]: learning rate : 0.000006 loss : 0.400566 +[00:11:42.523] iteration 94900 [1059.85 sec]: learning rate : 0.000006 loss : 0.376825 +[00:13:13.454] iteration 95000 [1150.72 sec]: learning rate : 0.000006 loss : 0.563678 +[00:14:44.385] iteration 95100 [1241.65 sec]: learning rate : 0.000006 loss : 0.624233 +[00:16:15.289] iteration 95200 [1332.55 sec]: learning rate : 0.000006 loss : 0.462889 +[00:17:46.197] iteration 95300 [1423.46 sec]: learning rate : 0.000006 loss : 0.407967 +[00:19:17.153] iteration 95400 [1514.42 sec]: learning rate : 0.000006 loss : 0.438302 +[00:20:48.054] iteration 95500 [1605.32 sec]: learning rate : 0.000006 loss : 0.310125 +[00:22:19.029] iteration 95600 [1696.29 sec]: learning rate : 0.000006 loss : 0.378368 +[00:23:49.959] iteration 95700 [1787.22 sec]: learning rate : 0.000006 loss : 0.761212 +[00:25:20.847] iteration 95800 [1878.11 sec]: learning rate : 0.000006 loss : 0.657146 +[00:25:37.174] Epoch 45 Evaluation: +[00:26:28.676] average MSE: 0.03840780630707741 average PSNR: 29.860176914051618 average SSIM: 0.7375492524915651 +[00:27:43.621] iteration 95900 [74.88 sec]: learning rate : 0.000006 loss : 0.460988 +[00:29:14.570] iteration 96000 [165.82 sec]: learning rate : 0.000006 loss : 0.471035 +[00:30:45.471] iteration 96100 [256.73 sec]: learning rate : 0.000006 loss : 0.389062 +[00:32:16.401] iteration 96200 [347.66 sec]: learning rate : 0.000006 loss : 0.482353 +[00:33:47.333] iteration 96300 [438.59 sec]: learning rate : 0.000006 loss : 0.523305 +[00:35:18.269] iteration 96400 [529.52 sec]: learning rate : 0.000006 loss : 0.423349 +[00:36:49.219] iteration 96500 [620.47 sec]: learning rate : 0.000006 loss : 0.345884 +[00:38:20.111] iteration 96600 [711.37 sec]: learning rate : 0.000006 loss : 0.467321 +[00:39:51.047] iteration 96700 [802.30 sec]: learning rate : 0.000006 loss : 0.495488 +[00:41:21.975] iteration 96800 [893.23 sec]: learning rate : 0.000006 loss : 0.465610 +[00:42:52.856] iteration 96900 [984.11 sec]: learning rate : 
0.000006 loss : 0.391495 +[00:44:23.791] iteration 97000 [1075.05 sec]: learning rate : 0.000006 loss : 0.441412 +[00:45:54.698] iteration 97100 [1165.95 sec]: learning rate : 0.000006 loss : 0.300890 +[00:47:25.642] iteration 97200 [1256.90 sec]: learning rate : 0.000006 loss : 0.404443 +[00:48:56.572] iteration 97300 [1347.83 sec]: learning rate : 0.000006 loss : 0.594813 +[00:50:27.462] iteration 97400 [1438.72 sec]: learning rate : 0.000006 loss : 0.505692 +[00:51:58.391] iteration 97500 [1529.65 sec]: learning rate : 0.000006 loss : 0.359663 +[00:53:29.301] iteration 97600 [1620.56 sec]: learning rate : 0.000006 loss : 0.302452 +[00:55:00.203] iteration 97700 [1711.46 sec]: learning rate : 0.000006 loss : 0.629164 +[00:56:31.152] iteration 97800 [1802.41 sec]: learning rate : 0.000006 loss : 0.446843 +[00:58:02.099] iteration 97900 [1893.35 sec]: learning rate : 0.000006 loss : 0.579801 +[00:58:02.977] Epoch 46 Evaluation: +[00:58:52.731] average MSE: 0.038291942328214645 average PSNR: 29.874235605510005 average SSIM: 0.7375116846004657 +[01:00:23.013] iteration 98000 [90.21 sec]: learning rate : 0.000006 loss : 0.498144 +[01:01:54.004] iteration 98100 [181.20 sec]: learning rate : 0.000006 loss : 0.487991 +[01:03:24.970] iteration 98200 [272.24 sec]: learning rate : 0.000006 loss : 0.400900 +[01:04:55.866] iteration 98300 [363.07 sec]: learning rate : 0.000006 loss : 0.442160 +[01:06:26.811] iteration 98400 [454.01 sec]: learning rate : 0.000006 loss : 0.654456 +[01:07:57.726] iteration 98500 [544.93 sec]: learning rate : 0.000006 loss : 0.335293 +[01:09:28.698] iteration 98600 [635.90 sec]: learning rate : 0.000006 loss : 0.437777 +[01:10:59.639] iteration 98700 [726.84 sec]: learning rate : 0.000006 loss : 0.461514 +[01:12:30.538] iteration 98800 [817.74 sec]: learning rate : 0.000006 loss : 0.459173 +[01:14:01.441] iteration 98900 [908.64 sec]: learning rate : 0.000006 loss : 0.519033 +[01:15:32.401] iteration 99000 [999.60 sec]: learning rate : 0.000006 loss : 0.493714 +[01:17:03.310] iteration 99100 [1090.51 sec]: learning rate : 0.000006 loss : 0.528906 +[01:18:34.247] iteration 99200 [1181.45 sec]: learning rate : 0.000006 loss : 0.290732 +[01:20:05.161] iteration 99300 [1272.36 sec]: learning rate : 0.000006 loss : 0.379288 +[01:21:36.121] iteration 99400 [1363.32 sec]: learning rate : 0.000006 loss : 0.630965 +[01:23:07.039] iteration 99500 [1454.24 sec]: learning rate : 0.000006 loss : 0.590325 +[01:24:37.927] iteration 99600 [1545.13 sec]: learning rate : 0.000006 loss : 0.475650 +[01:26:08.865] iteration 99700 [1636.06 sec]: learning rate : 0.000006 loss : 0.659469 +[01:27:39.810] iteration 99800 [1727.01 sec]: learning rate : 0.000006 loss : 0.537113 +[01:29:10.712] iteration 99900 [1817.91 sec]: learning rate : 0.000006 loss : 0.527537 +[01:30:27.091] Epoch 47 Evaluation: +[01:31:17.350] average MSE: 0.038327161222696304 average PSNR: 29.872974372125327 average SSIM: 0.7370267023459122 +[01:31:32.189] iteration 100000 [14.77 sec]: learning rate : 0.000002 loss : 0.635106 +[01:31:32.344] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_100000.pth +[01:31:33.225] Epoch 48 Evaluation: +[01:32:23.445] average MSE: 0.03833882138133049 average PSNR: 29.87054699465335 average SSIM: 0.7373177569397077 +[01:32:23.714] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_100000.pth diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log/events.out.tfevents.1752647681.GCRSANDBOX133 
b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log/events.out.tfevents.1752647681.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..168efad1f21a4738b680af50e8923ce524df337a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log/events.out.tfevents.1752647681.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:126a8b89c3cfadf585cf7729bd1fabaee5faf29dffe843618dc35bf44129e8c8 +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..7db7c356b34dc1bc2ff2b58454550c3f7bc9308e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cf4c5b56459ec6a65e93165ea87a4b60b58aa0ad62699c975c761694654cd41 +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..5bd27db841ecbdbc9ca3988933576254d21532cd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_/log.txt @@ -0,0 +1,1135 @@ +[05:57:09.149] Namespace(root_path='/home/v-qichen3/MRI_recon/data/fastmri', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='0', exp='FSMNet_fastmri_4x', max_iterations=100000, batch_size=4, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[05:58:48.694] iteration 100 [93.29 sec]: learning rate : 0.000100 loss : 0.918586 +[06:00:19.661] iteration 200 [184.26 sec]: learning rate : 0.000100 loss : 0.900872 +[06:01:50.565] iteration 300 [275.16 sec]: learning rate : 0.000100 loss : 0.953547 +[06:03:21.498] iteration 400 [366.09 sec]: learning rate : 0.000100 loss : 0.427337 +[06:04:52.533] iteration 500 [457.13 sec]: learning rate : 0.000100 loss : 1.143515 +[06:06:23.440] iteration 600 [548.04 sec]: learning rate : 0.000100 loss : 0.819123 +[06:07:54.401] iteration 700 [639.00 sec]: learning rate : 0.000100 loss : 0.626560 +[06:09:25.381] iteration 800 [729.98 sec]: learning rate : 0.000100 loss : 0.820221 +[06:10:56.288] iteration 900 [820.89 sec]: learning rate : 0.000100 loss : 0.816579 +[06:12:27.251] iteration 1000 [911.85 sec]: learning rate : 0.000100 loss : 0.650628 +[06:13:58.168] iteration 1100 [1002.77 sec]: learning rate : 0.000100 loss : 0.656048 +[06:15:29.137] iteration 1200 [1093.73 sec]: learning rate : 0.000100 loss : 1.078188 +[06:17:00.093] iteration 1300 [1184.69 sec]: learning rate : 0.000100 loss : 0.648531 +[06:18:31.004] iteration 1400 [1275.60 sec]: learning rate : 
0.000100 loss : 0.740182 +[06:20:01.957] iteration 1500 [1366.55 sec]: learning rate : 0.000100 loss : 0.559478 +[06:21:32.909] iteration 1600 [1457.51 sec]: learning rate : 0.000100 loss : 0.549957 +[06:23:03.819] iteration 1700 [1548.42 sec]: learning rate : 0.000100 loss : 1.227873 +[06:24:34.793] iteration 1800 [1639.39 sec]: learning rate : 0.000100 loss : 1.007279 +[06:26:05.711] iteration 1900 [1730.31 sec]: learning rate : 0.000100 loss : 0.706404 +[06:27:36.615] iteration 2000 [1821.21 sec]: learning rate : 0.000100 loss : 0.775094 +[06:28:52.083] Epoch 0 Evaluation: +[06:29:44.455] average MSE: 0.05376233533024788 average PSNR: 28.011207135521175 average SSIM: 0.6753641708606896 +[06:30:00.156] iteration 2100 [15.63 sec]: learning rate : 0.000100 loss : 0.938743 +[06:31:31.169] iteration 2200 [106.65 sec]: learning rate : 0.000100 loss : 0.470765 +[06:33:02.070] iteration 2300 [197.55 sec]: learning rate : 0.000100 loss : 0.546204 +[06:34:33.028] iteration 2400 [288.51 sec]: learning rate : 0.000100 loss : 0.464111 +[06:36:03.984] iteration 2500 [379.46 sec]: learning rate : 0.000100 loss : 0.763503 +[06:37:34.895] iteration 2600 [470.37 sec]: learning rate : 0.000100 loss : 0.817069 +[06:39:05.892] iteration 2700 [561.37 sec]: learning rate : 0.000100 loss : 0.516133 +[06:40:36.853] iteration 2800 [652.33 sec]: learning rate : 0.000100 loss : 0.601265 +[06:42:07.748] iteration 2900 [743.23 sec]: learning rate : 0.000100 loss : 1.035532 +[06:43:38.684] iteration 3000 [834.16 sec]: learning rate : 0.000100 loss : 0.575478 +[06:45:09.578] iteration 3100 [925.06 sec]: learning rate : 0.000100 loss : 0.937929 +[06:46:40.526] iteration 3200 [1016.00 sec]: learning rate : 0.000100 loss : 0.619612 +[06:48:11.464] iteration 3300 [1106.94 sec]: learning rate : 0.000100 loss : 0.584772 +[06:49:42.359] iteration 3400 [1197.84 sec]: learning rate : 0.000100 loss : 0.660041 +[06:51:13.264] iteration 3500 [1288.74 sec]: learning rate : 0.000100 loss : 0.567186 +[06:52:44.203] iteration 3600 [1379.68 sec]: learning rate : 0.000100 loss : 0.471382 +[06:54:15.111] iteration 3700 [1470.59 sec]: learning rate : 0.000100 loss : 0.434392 +[06:55:46.066] iteration 3800 [1561.55 sec]: learning rate : 0.000100 loss : 0.780614 +[06:57:16.967] iteration 3900 [1652.44 sec]: learning rate : 0.000100 loss : 0.506792 +[06:58:47.910] iteration 4000 [1743.39 sec]: learning rate : 0.000100 loss : 0.609247 +[07:00:18.858] iteration 4100 [1834.34 sec]: learning rate : 0.000100 loss : 0.817836 +[07:01:18.837] Epoch 1 Evaluation: +[07:02:12.023] average MSE: 0.05338025093078613 average PSNR: 28.088190127494332 average SSIM: 0.6978901765515252 +[07:02:43.225] iteration 4200 [31.14 sec]: learning rate : 0.000100 loss : 0.713182 +[07:04:14.268] iteration 4300 [122.18 sec]: learning rate : 0.000100 loss : 0.887023 +[07:05:45.239] iteration 4400 [213.15 sec]: learning rate : 0.000100 loss : 0.523889 +[07:07:16.144] iteration 4500 [304.05 sec]: learning rate : 0.000100 loss : 0.790979 +[07:08:47.055] iteration 4600 [394.96 sec]: learning rate : 0.000100 loss : 0.365239 +[07:10:18.001] iteration 4700 [485.91 sec]: learning rate : 0.000100 loss : 0.368119 +[07:11:48.897] iteration 4800 [576.81 sec]: learning rate : 0.000100 loss : 0.696716 +[07:13:19.850] iteration 4900 [667.76 sec]: learning rate : 0.000100 loss : 0.457110 +[07:14:50.742] iteration 5000 [758.65 sec]: learning rate : 0.000100 loss : 0.365158 +[07:16:21.659] iteration 5100 [849.57 sec]: learning rate : 0.000100 loss : 0.427239 +[07:17:52.592] iteration 5200 
+[per-iteration training entries, printed every 100 iterations as "iteration N [elapsed sec]: learning rate : ... loss : ...", are condensed here; only the epoch evaluations, learning-rate steps, and checkpoint saves are kept]
+[07:34:36.986] Epoch 2 Evaluation: average MSE: 0.04912297800183296 average PSNR: 28.514683170126947 average SSIM: 0.7075998745003204
+[08:07:03.724] Epoch 3 Evaluation: average MSE: 0.04785247519612312 average PSNR: 28.673283606504953 average SSIM: 0.7077329499338162
+[08:39:28.504] Epoch 4 Evaluation: average MSE: 0.04786675050854683 average PSNR: 28.69952803516013 average SSIM: 0.7116247860605504
+[09:11:53.377] Epoch 5 Evaluation: average MSE: 0.04653492569923401 average PSNR: 28.842373522965143 average SSIM: 0.7144229692357346
+[09:44:20.365] Epoch 6 Evaluation: average MSE: 0.045401304960250854 average PSNR: 28.964724276122272 average SSIM: 0.7165064516601239
+[10:16:45.803] Epoch 7 Evaluation: average MSE: 0.043812498450279236 average PSNR: 29.137799492167506 average SSIM: 0.7184595439080809
+[10:49:13.230] Epoch 8 Evaluation: average MSE: 0.043516892939805984 average PSNR: 29.21377825534631 average SSIM: 0.7215472669797558
+[11:08:13.202] iteration 20000: learning rate stepped 0.000100 -> 0.000050; save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_20000.pth
+[11:21:38.457] Epoch 9 Evaluation: average MSE: 0.04381684586405754 average PSNR: 29.208554831302305 average SSIM: 0.7225160127143108
+[11:54:03.917] Epoch 10 Evaluation: average MSE: 0.04263733699917793 average PSNR: 29.33246923201823 average SSIM: 0.7240476148097196
+[12:26:28.691] Epoch 11 Evaluation: average MSE: 0.04385831952095032 average PSNR: 29.226189240782272 average SSIM: 0.7230770171897762
+[12:58:53.214] Epoch 12 Evaluation: average MSE: 0.043065521866083145 average PSNR: 29.336877218546984 average SSIM: 0.7261972505837041
+[13:31:17.823] Epoch 13 Evaluation: average MSE: 0.04206184670329094 average PSNR: 29.418786831381436 average SSIM: 0.7271767294163953
+[14:03:45.374] Epoch 14 Evaluation: average MSE: 0.042470891028642654 average PSNR: 29.40129733421782 average SSIM: 0.7270275230083372
+[14:36:09.855] Epoch 15 Evaluation: average MSE: 0.041972726583480835 average PSNR: 29.467893559758352 average SSIM: 0.7287978783315362
+[15:08:36.119] Epoch 16 Evaluation: average MSE: 0.041751667857170105 average PSNR: 29.481927814803658 average SSIM: 0.7289790134939285
+[15:41:02.758] Epoch 17 Evaluation: average MSE: 0.04174785315990448 average PSNR: 29.482428098400238 average SSIM: 0.728556093191531
+[16:13:28.951] Epoch 18 Evaluation: average MSE: 0.041311558336019516 average PSNR: 29.52444628321888 average SSIM: 0.7310337429538688
+[16:19:53.951] iteration 40000: learning rate stepped 0.000050 -> 0.000025; save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_40000.pth
+[16:45:54.430] Epoch 19 Evaluation: average MSE: 0.04164968430995941 average PSNR: 29.516813877101754 average SSIM: 0.730607253937201
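The learning rate visible in this log halves at iteration 20000 and again at 40000 (and later at 60000 and 80000), with a checkpoint written at each milestone. The scheduler actually used in the FSMNet training code is not shown here; the following is only a minimal PyTorch sketch, with a placeholder model and optimizer, that would reproduce the same halve-every-20,000-iterations pattern and periodic saving:

    import torch

    # Placeholder model/optimizer; the real FSMNet training setup is not part of this log.
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # Halve the learning rate every 20,000 iterations: 1e-4 -> 5e-5 -> 2.5e-5 -> 1.25e-5 -> 6.25e-6
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20000, gamma=0.5)

    for iteration in range(1, 80001):
        optimizer.step()    # forward pass, loss, and loss.backward() would precede this in real training
        scheduler.step()
        if iteration % 20000 == 0:
            # Mirrors the "save model to .../iter_<N>.pth" lines in the log.
            torch.save(model.state_dict(), f"iter_{iteration}.pth")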
+[per-iteration training entries remain condensed; epoch evaluations, learning-rate steps, and checkpoint saves continue below]
+[17:18:20.237] Epoch 20 Evaluation: average MSE: 0.04070758447051048 average PSNR: 29.609092591309672 average SSIM: 0.7318649283463773
+[17:50:46.975] Epoch 21 Evaluation: average MSE: 0.040890298783779144 average PSNR: 29.598552237212647 average SSIM: 0.7316433317780927
+[18:23:13.673] Epoch 22 Evaluation: average MSE: 0.04122582823038101 average PSNR: 29.571263394898796 average SSIM: 0.7317430453221075
+[18:55:41.159] Epoch 23 Evaluation: average MSE: 0.04065868258476257 average PSNR: 29.62770065769784 average SSIM: 0.7318551552877591
+[19:28:05.891] Epoch 24 Evaluation: average MSE: 0.04023030772805214 average PSNR: 29.670021674896077 average SSIM: 0.7329612137492515
+[20:00:30.632] Epoch 25 Evaluation: average MSE: 0.040409427136182785 average PSNR: 29.660070502240703 average SSIM: 0.733467282205697
+[20:32:55.092] Epoch 26 Evaluation: average MSE: 0.040384408086538315 average PSNR: 29.645839179362273 average SSIM: 0.7331557857066643
+[21:05:19.794] Epoch 27 Evaluation: average MSE: 0.03976333141326904 average PSNR: 29.7177832661263 average SSIM: 0.7342210564427983
+[21:30:44.304] iteration 60000: learning rate stepped 0.000025 -> 0.000013; save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_60000.pth
+[21:37:47.749] Epoch 28 Evaluation: average MSE: 0.04039575904607773 average PSNR: 29.66131735398642 average SSIM: 0.733798454716447
+[22:10:14.493] Epoch 29 Evaluation: average MSE: 0.04022200405597687 average PSNR: 29.686270151307152 average SSIM: 0.7336164605828607
+[22:42:40.446] Epoch 30 Evaluation: average MSE: 0.040326669812202454 average PSNR: 29.67795940075506 average SSIM: 0.7338952169185964
+[23:15:05.111] Epoch 31 Evaluation: average MSE: 0.04023518040776253 average PSNR: 29.68585020693114 average SSIM: 0.7341906972551063
+[23:47:29.892] Epoch 32 Evaluation: average MSE: 0.04023408517241478 average PSNR: 29.679834962282865 average SSIM: 0.7343457959898535
+[00:19:54.685] Epoch 33 Evaluation: average MSE: 0.03990236297249794 average PSNR: 29.719041532456497 average SSIM: 0.7346594767830833
+[00:52:19.796] Epoch 34 Evaluation: average MSE: 0.03997483476996422 average PSNR: 29.71694627364347 average SSIM: 0.7346457487812943
+[01:24:44.486] Epoch 35 Evaluation: average MSE: 0.03961585462093353 average PSNR: 29.746739211547606 average SSIM: 0.7350601165065878
+[01:57:09.611] Epoch 36 Evaluation: average MSE: 0.03989878669381142 average PSNR: 29.717812174688245 average SSIM: 0.7350136523962864
+[02:29:34.611] Epoch 37 Evaluation: average MSE: 0.040141645818948746 average PSNR: 29.701255109672033 average SSIM: 0.7348380294414762
+[02:42:24.270] iteration 80000: learning rate stepped 0.000013 -> 0.000006; save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_80000.pth
+[03:02:00.965] Epoch 38 Evaluation: average MSE: 0.03965437412261963 average PSNR: 29.749585628280233 average SSIM: 0.7352594712236916
+[03:34:26.798] Epoch 39 Evaluation: average MSE: 0.039899587631225586 average PSNR: 29.724261331087817 average SSIM: 0.7353158102302899
+[03:38:41.644] iteration 83600 [254.78 sec]: learning rate : 0.000006 loss : 0.489289 +[03:40:12.521] iteration 83700 [345.66 sec]: learning rate : 0.000006 loss : 0.424795 +[03:41:43.444] iteration 83800 [436.58 sec]: learning rate : 0.000006 loss : 0.532242 +[03:43:14.350] iteration 83900 [527.49 sec]: learning rate : 0.000006 loss : 0.563425 +[03:44:45.240] iteration 84000 [618.38 sec]: learning rate : 0.000006 loss : 0.440198 +[03:46:16.155] iteration 84100 [709.29 sec]: learning rate : 0.000006 loss : 0.721244 +[03:47:47.039] iteration 84200 [800.17 sec]: learning rate : 0.000006 loss : 0.665638 +[03:49:17.906] iteration 84300 [891.04 sec]: learning rate : 0.000006 loss : 1.020897 +[03:50:48.842] iteration 84400 [981.98 sec]: learning rate : 0.000006 loss : 0.316416 +[03:52:19.717] iteration 84500 [1072.85 sec]: learning rate : 0.000006 loss : 0.482595 +[03:53:50.642] iteration 84600 [1163.78 sec]: learning rate : 0.000006 loss : 0.457179 +[03:55:21.568] iteration 84700 [1254.70 sec]: learning rate : 0.000006 loss : 0.492077 +[03:56:52.436] iteration 84800 [1345.57 sec]: learning rate : 0.000006 loss : 0.537565 +[03:58:23.363] iteration 84900 [1436.50 sec]: learning rate : 0.000006 loss : 0.373885 +[03:59:54.306] iteration 85000 [1527.44 sec]: learning rate : 0.000006 loss : 0.548875 +[04:01:25.211] iteration 85100 [1618.35 sec]: learning rate : 0.000006 loss : 0.500583 +[04:02:56.136] iteration 85200 [1709.27 sec]: learning rate : 0.000006 loss : 0.445543 +[04:04:27.024] iteration 85300 [1800.16 sec]: learning rate : 0.000006 loss : 0.487929 +[04:05:57.913] iteration 85400 [1891.05 sec]: learning rate : 0.000006 loss : 0.399146 +[04:06:00.608] Epoch 40 Evaluation: +[04:06:51.220] average MSE: 0.039701323956251144 average PSNR: 29.746713839969022 average SSIM: 0.7354508383923761 +[04:08:19.686] iteration 85500 [88.40 sec]: learning rate : 0.000006 loss : 0.687539 +[04:09:50.562] iteration 85600 [179.27 sec]: learning rate : 0.000006 loss : 0.394097 +[04:11:21.480] iteration 85700 [270.19 sec]: learning rate : 0.000006 loss : 0.632886 +[04:12:52.411] iteration 85800 [361.12 sec]: learning rate : 0.000006 loss : 0.444699 +[04:14:23.286] iteration 85900 [452.00 sec]: learning rate : 0.000006 loss : 0.720585 +[04:15:54.237] iteration 86000 [542.95 sec]: learning rate : 0.000006 loss : 0.342239 +[04:17:25.159] iteration 86100 [633.87 sec]: learning rate : 0.000006 loss : 0.348174 +[04:18:56.038] iteration 86200 [724.75 sec]: learning rate : 0.000006 loss : 0.589353 +[04:20:26.961] iteration 86300 [815.67 sec]: learning rate : 0.000006 loss : 0.544907 +[04:21:57.840] iteration 86400 [906.55 sec]: learning rate : 0.000006 loss : 0.446403 +[04:23:28.788] iteration 86500 [997.50 sec]: learning rate : 0.000006 loss : 0.527030 +[04:24:59.708] iteration 86600 [1088.42 sec]: learning rate : 0.000006 loss : 0.592178 +[04:26:30.587] iteration 86700 [1179.30 sec]: learning rate : 0.000006 loss : 0.235294 +[04:28:01.511] iteration 86800 [1270.22 sec]: learning rate : 0.000006 loss : 0.520654 +[04:29:32.449] iteration 86900 [1361.16 sec]: learning rate : 0.000006 loss : 0.285544 +[04:31:03.332] iteration 87000 [1452.05 sec]: learning rate : 0.000006 loss : 0.418672 +[04:32:34.259] iteration 87100 [1542.97 sec]: learning rate : 0.000006 loss : 0.510922 +[04:34:05.195] iteration 87200 [1633.91 sec]: learning rate : 0.000006 loss : 0.361550 +[04:35:36.086] iteration 87300 [1724.80 sec]: learning rate : 0.000006 loss : 0.406086 +[04:37:07.012] iteration 87400 [1815.73 sec]: learning rate : 0.000006 loss : 
0.541354 +[04:38:25.130] Epoch 41 Evaluation: +[04:39:15.941] average MSE: 0.03999185562133789 average PSNR: 29.717549683335903 average SSIM: 0.7351032288842041 +[04:39:28.926] iteration 87500 [12.92 sec]: learning rate : 0.000006 loss : 0.369912 +[04:40:59.922] iteration 87600 [103.91 sec]: learning rate : 0.000006 loss : 0.398935 +[04:42:30.863] iteration 87700 [194.85 sec]: learning rate : 0.000006 loss : 0.490323 +[04:44:01.758] iteration 87800 [285.75 sec]: learning rate : 0.000006 loss : 0.795707 +[04:45:32.737] iteration 87900 [376.73 sec]: learning rate : 0.000006 loss : 0.587287 +[04:47:03.693] iteration 88000 [467.69 sec]: learning rate : 0.000006 loss : 0.349947 +[04:48:34.593] iteration 88100 [558.59 sec]: learning rate : 0.000006 loss : 0.375721 +[04:50:05.533] iteration 88200 [649.52 sec]: learning rate : 0.000006 loss : 0.583773 +[04:51:36.426] iteration 88300 [740.42 sec]: learning rate : 0.000006 loss : 0.401172 +[04:53:07.382] iteration 88400 [831.37 sec]: learning rate : 0.000006 loss : 0.553198 +[04:54:38.343] iteration 88500 [922.33 sec]: learning rate : 0.000006 loss : 0.531683 +[04:56:09.238] iteration 88600 [1013.23 sec]: learning rate : 0.000006 loss : 0.442264 +[04:57:40.181] iteration 88700 [1104.17 sec]: learning rate : 0.000006 loss : 0.623831 +[04:59:11.145] iteration 88800 [1195.14 sec]: learning rate : 0.000006 loss : 0.518669 +[05:00:42.040] iteration 88900 [1286.03 sec]: learning rate : 0.000006 loss : 0.537102 +[05:02:12.974] iteration 89000 [1376.97 sec]: learning rate : 0.000006 loss : 0.303466 +[05:03:43.942] iteration 89100 [1467.93 sec]: learning rate : 0.000006 loss : 0.541519 +[05:05:14.847] iteration 89200 [1558.84 sec]: learning rate : 0.000006 loss : 0.215486 +[05:06:45.796] iteration 89300 [1649.79 sec]: learning rate : 0.000006 loss : 0.302666 +[05:08:16.713] iteration 89400 [1740.71 sec]: learning rate : 0.000006 loss : 0.551956 +[05:09:47.667] iteration 89500 [1831.66 sec]: learning rate : 0.000006 loss : 0.643311 +[05:10:50.356] Epoch 42 Evaluation: +[05:11:41.802] average MSE: 0.03977058455348015 average PSNR: 29.738973848362104 average SSIM: 0.7354007928724572 +[05:12:10.288] iteration 89600 [28.42 sec]: learning rate : 0.000006 loss : 0.452661 +[05:13:41.188] iteration 89700 [119.32 sec]: learning rate : 0.000006 loss : 0.345732 +[05:15:12.125] iteration 89800 [210.26 sec]: learning rate : 0.000006 loss : 0.417775 +[05:16:43.029] iteration 89900 [301.16 sec]: learning rate : 0.000006 loss : 0.362364 +[05:18:13.919] iteration 90000 [392.05 sec]: learning rate : 0.000006 loss : 0.635215 +[05:19:44.860] iteration 90100 [482.99 sec]: learning rate : 0.000006 loss : 0.444994 +[05:21:15.751] iteration 90200 [573.88 sec]: learning rate : 0.000006 loss : 0.716137 +[05:22:46.705] iteration 90300 [664.84 sec]: learning rate : 0.000006 loss : 0.718335 +[05:24:17.666] iteration 90400 [755.80 sec]: learning rate : 0.000006 loss : 0.482738 +[05:25:48.557] iteration 90500 [846.69 sec]: learning rate : 0.000006 loss : 0.455896 +[05:27:19.503] iteration 90600 [937.63 sec]: learning rate : 0.000006 loss : 0.577615 +[05:28:50.453] iteration 90700 [1028.58 sec]: learning rate : 0.000006 loss : 0.304052 +[05:30:21.352] iteration 90800 [1119.48 sec]: learning rate : 0.000006 loss : 0.488172 +[05:31:52.309] iteration 90900 [1210.44 sec]: learning rate : 0.000006 loss : 0.351912 +[05:33:23.254] iteration 91000 [1301.38 sec]: learning rate : 0.000006 loss : 0.631828 +[05:34:54.154] iteration 91100 [1392.28 sec]: learning rate : 0.000006 loss : 0.380733 
+[05:36:25.094] iteration 91200 [1483.22 sec]: learning rate : 0.000006 loss : 0.578017 +[05:37:56.053] iteration 91300 [1574.18 sec]: learning rate : 0.000006 loss : 0.337623 +[05:39:26.964] iteration 91400 [1665.10 sec]: learning rate : 0.000006 loss : 0.397273 +[05:40:57.903] iteration 91500 [1756.03 sec]: learning rate : 0.000006 loss : 0.457402 +[05:42:28.793] iteration 91600 [1846.92 sec]: learning rate : 0.000006 loss : 0.269356 +[05:43:16.084] Epoch 43 Evaluation: +[05:44:06.710] average MSE: 0.040039125829935074 average PSNR: 29.7122653338449 average SSIM: 0.7349268707426715 +[05:44:50.598] iteration 91700 [43.82 sec]: learning rate : 0.000006 loss : 0.407613 +[05:46:21.545] iteration 91800 [134.77 sec]: learning rate : 0.000006 loss : 0.545782 +[05:47:52.427] iteration 91900 [225.65 sec]: learning rate : 0.000006 loss : 0.363113 +[05:49:23.352] iteration 92000 [316.58 sec]: learning rate : 0.000006 loss : 0.751430 +[05:50:54.248] iteration 92100 [407.47 sec]: learning rate : 0.000006 loss : 0.609690 +[05:52:25.126] iteration 92200 [498.35 sec]: learning rate : 0.000006 loss : 0.685917 +[05:53:56.047] iteration 92300 [589.27 sec]: learning rate : 0.000006 loss : 0.760396 +[05:55:26.974] iteration 92400 [680.20 sec]: learning rate : 0.000006 loss : 0.495093 +[05:56:57.849] iteration 92500 [771.07 sec]: learning rate : 0.000006 loss : 0.403633 +[05:58:28.800] iteration 92600 [862.02 sec]: learning rate : 0.000006 loss : 0.604688 +[05:59:59.701] iteration 92700 [952.92 sec]: learning rate : 0.000006 loss : 0.332899 +[06:01:30.577] iteration 92800 [1043.80 sec]: learning rate : 0.000006 loss : 0.681375 +[06:03:01.494] iteration 92900 [1134.72 sec]: learning rate : 0.000006 loss : 0.415556 +[06:04:32.431] iteration 93000 [1225.65 sec]: learning rate : 0.000006 loss : 0.308055 +[06:06:03.320] iteration 93100 [1316.54 sec]: learning rate : 0.000006 loss : 0.532678 +[06:07:34.267] iteration 93200 [1407.49 sec]: learning rate : 0.000006 loss : 0.637684 +[06:09:05.203] iteration 93300 [1498.49 sec]: learning rate : 0.000006 loss : 0.532963 +[06:10:36.109] iteration 93400 [1589.33 sec]: learning rate : 0.000006 loss : 0.507914 +[06:12:07.043] iteration 93500 [1680.27 sec]: learning rate : 0.000006 loss : 0.583308 +[06:13:37.912] iteration 93600 [1771.14 sec]: learning rate : 0.000006 loss : 0.350984 +[06:15:08.809] iteration 93700 [1862.03 sec]: learning rate : 0.000006 loss : 0.365254 +[06:15:40.585] Epoch 44 Evaluation: +[06:16:31.149] average MSE: 0.03957636281847954 average PSNR: 29.762342685944493 average SSIM: 0.7357898659139311 +[06:17:30.599] iteration 93800 [59.38 sec]: learning rate : 0.000006 loss : 0.349577 +[06:19:01.476] iteration 93900 [150.26 sec]: learning rate : 0.000006 loss : 0.261287 +[06:20:32.378] iteration 94000 [241.16 sec]: learning rate : 0.000006 loss : 0.331374 +[06:22:03.318] iteration 94100 [332.10 sec]: learning rate : 0.000006 loss : 0.559623 +[06:23:34.199] iteration 94200 [422.98 sec]: learning rate : 0.000006 loss : 0.375458 +[06:25:05.136] iteration 94300 [513.92 sec]: learning rate : 0.000006 loss : 0.450888 +[06:26:36.058] iteration 94400 [604.84 sec]: learning rate : 0.000006 loss : 0.378683 +[06:28:06.954] iteration 94500 [695.74 sec]: learning rate : 0.000006 loss : 0.407780 +[06:29:37.883] iteration 94600 [786.67 sec]: learning rate : 0.000006 loss : 0.313103 +[06:31:08.764] iteration 94700 [877.55 sec]: learning rate : 0.000006 loss : 0.469726 +[06:32:39.704] iteration 94800 [968.49 sec]: learning rate : 0.000006 loss : 0.634117 +[06:34:10.636] 
iteration 94900 [1059.42 sec]: learning rate : 0.000006 loss : 0.471475 +[06:35:41.520] iteration 95000 [1150.31 sec]: learning rate : 0.000006 loss : 0.718103 +[06:37:12.433] iteration 95100 [1241.22 sec]: learning rate : 0.000006 loss : 0.478889 +[06:38:43.378] iteration 95200 [1332.16 sec]: learning rate : 0.000006 loss : 0.411333 +[06:40:14.260] iteration 95300 [1423.04 sec]: learning rate : 0.000006 loss : 0.408598 +[06:41:45.193] iteration 95400 [1513.97 sec]: learning rate : 0.000006 loss : 0.283219 +[06:43:16.071] iteration 95500 [1604.85 sec]: learning rate : 0.000006 loss : 0.359699 +[06:44:46.986] iteration 95600 [1695.77 sec]: learning rate : 0.000006 loss : 0.696428 +[06:46:17.900] iteration 95700 [1786.68 sec]: learning rate : 0.000006 loss : 0.430501 +[06:47:48.778] iteration 95800 [1877.56 sec]: learning rate : 0.000006 loss : 0.807338 +[06:48:05.102] Epoch 45 Evaluation: +[06:48:55.573] average MSE: 0.039626020938158035 average PSNR: 29.75882496926239 average SSIM: 0.7359607550464302 +[06:50:10.492] iteration 95900 [74.85 sec]: learning rate : 0.000006 loss : 0.612016 +[06:51:41.395] iteration 96000 [165.75 sec]: learning rate : 0.000006 loss : 0.260208 +[06:53:12.341] iteration 96100 [256.70 sec]: learning rate : 0.000006 loss : 0.426758 +[06:54:43.251] iteration 96200 [347.61 sec]: learning rate : 0.000006 loss : 0.451288 +[06:56:14.152] iteration 96300 [438.51 sec]: learning rate : 0.000006 loss : 0.651844 +[06:57:45.159] iteration 96400 [529.52 sec]: learning rate : 0.000006 loss : 0.258685 +[06:59:16.064] iteration 96500 [620.42 sec]: learning rate : 0.000006 loss : 0.270012 +[07:00:46.971] iteration 96600 [711.33 sec]: learning rate : 0.000006 loss : 0.725691 +[07:02:17.891] iteration 96700 [802.25 sec]: learning rate : 0.000006 loss : 0.654056 +[07:03:48.846] iteration 96800 [893.21 sec]: learning rate : 0.000006 loss : 0.325239 +[07:05:19.744] iteration 96900 [984.10 sec]: learning rate : 0.000006 loss : 0.437645 +[07:06:50.650] iteration 97000 [1075.01 sec]: learning rate : 0.000006 loss : 0.429330 +[07:08:21.549] iteration 97100 [1165.91 sec]: learning rate : 0.000006 loss : 0.701867 +[07:09:52.486] iteration 97200 [1256.85 sec]: learning rate : 0.000006 loss : 0.412283 +[07:11:23.436] iteration 97300 [1347.80 sec]: learning rate : 0.000006 loss : 0.269945 +[07:12:54.333] iteration 97400 [1438.69 sec]: learning rate : 0.000006 loss : 0.624059 +[07:14:25.277] iteration 97500 [1529.64 sec]: learning rate : 0.000006 loss : 0.479841 +[07:15:56.227] iteration 97600 [1620.59 sec]: learning rate : 0.000006 loss : 0.680135 +[07:17:27.125] iteration 97700 [1711.49 sec]: learning rate : 0.000006 loss : 0.373554 +[07:18:58.065] iteration 97800 [1802.43 sec]: learning rate : 0.000006 loss : 0.402058 +[07:20:29.001] iteration 97900 [1893.36 sec]: learning rate : 0.000006 loss : 0.457124 +[07:20:29.886] Epoch 46 Evaluation: +[07:21:20.789] average MSE: 0.039534296840429306 average PSNR: 29.764532104210435 average SSIM: 0.7359105638674738 +[07:22:51.020] iteration 98000 [90.16 sec]: learning rate : 0.000006 loss : 0.499310 +[07:24:22.004] iteration 98100 [181.15 sec]: learning rate : 0.000006 loss : 0.506981 +[07:25:52.927] iteration 98200 [272.07 sec]: learning rate : 0.000006 loss : 1.093530 +[07:27:23.810] iteration 98300 [362.95 sec]: learning rate : 0.000006 loss : 0.319812 +[07:28:54.756] iteration 98400 [453.90 sec]: learning rate : 0.000006 loss : 0.544729 +[07:30:25.633] iteration 98500 [544.78 sec]: learning rate : 0.000006 loss : 0.351525 +[07:31:56.552] iteration 
98600 [635.69 sec]: learning rate : 0.000006 loss : 0.428450 +[07:33:27.488] iteration 98700 [726.63 sec]: learning rate : 0.000006 loss : 0.490764 +[07:34:58.388] iteration 98800 [817.53 sec]: learning rate : 0.000006 loss : 0.574718 +[07:36:29.307] iteration 98900 [908.45 sec]: learning rate : 0.000006 loss : 0.367181 +[07:38:00.238] iteration 99000 [999.38 sec]: learning rate : 0.000006 loss : 0.382892 +[07:39:31.117] iteration 99100 [1090.26 sec]: learning rate : 0.000006 loss : 0.705409 +[07:41:02.008] iteration 99200 [1181.15 sec]: learning rate : 0.000006 loss : 0.444080 +[07:42:32.876] iteration 99300 [1272.02 sec]: learning rate : 0.000006 loss : 0.420414 +[07:44:03.831] iteration 99400 [1362.97 sec]: learning rate : 0.000006 loss : 0.373555 +[07:45:34.733] iteration 99500 [1453.88 sec]: learning rate : 0.000006 loss : 0.472489 +[07:47:05.614] iteration 99600 [1544.76 sec]: learning rate : 0.000006 loss : 0.399361 +[07:48:36.555] iteration 99700 [1635.70 sec]: learning rate : 0.000006 loss : 0.435494 +[07:50:07.504] iteration 99800 [1726.71 sec]: learning rate : 0.000006 loss : 0.420722 +[07:51:38.390] iteration 99900 [1817.53 sec]: learning rate : 0.000006 loss : 0.454410 +[07:52:54.755] Epoch 47 Evaluation: +[07:53:46.376] average MSE: 0.0397125780582428 average PSNR: 29.75151548095596 average SSIM: 0.7358228072609542 +[07:54:01.181] iteration 100000 [14.74 sec]: learning rate : 0.000002 loss : 0.723175 +[07:54:01.337] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_100000.pth +[07:54:02.219] Epoch 48 Evaluation: +[07:54:52.650] average MSE: 0.0396689809858799 average PSNR: 29.7507126608584 average SSIM: 0.7357126064916403 +[07:54:52.911] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_100000.pth +===> Evaluate Metric <=== +Direct Results +------------------------------------ +NMSE: 4.0579 ± 0.7262 +PSNR: 28.6400 ± 1.5177 +SSIM: 0.6787 ± 0.0405 +------------------------------------ +===> Evaluate Metric <=== +Results +------------------------------------ +NMSE: 4.0220 ± 0.7149 +PSNR: 28.6776 ± 1.5230 +SSIM: 0.6916 ± 0.0361 +------------------------------------ + + +===> Evaluate Metric <=== +Direct Results +------------------------------------ +NMSE: 4.0579 ± 0.7262 +PSNR: 28.6400 ± 1.5177 +SSIM: 0.6787 ± 0.0405 +------------------------------------ +===> Evaluate Metric <=== +Results +------------------------------------ +NMSE: 4.0220 ± 0.7149 +PSNR: 28.6776 ± 1.5230 +SSIM: 0.6916 ± 0.0361 +------------------------------------ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_/log/events.out.tfevents.1752411435.GCRSANDBOX133 b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_/log/events.out.tfevents.1752411435.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..e3c627ca0cf78f3f546e0efeabc64ebc1dcdc39e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_/log/events.out.tfevents.1752411435.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55551f86ba418c09da02d514a3e506384357148afaa5b798b19b5f41dcbcfa0e +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..cd57c32970bb7f0988927db4c9b05289b89f7baf --- 
/dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a400a50c5c66c958eb4b5224288037d5e4efd367e3b2ea99cd89f4e0507132d +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..3ce838a08e99a85e805abd44dbbaf2ab556bf8f1 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/log.txt @@ -0,0 +1,1105 @@ +[23:37:17.408] Namespace(root_path='/home/v-qichen3/MRI_recon/data/fastmri', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='2', exp='FSMNet_fastmri_4x', max_iterations=100000, batch_size=4, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, snapshot_path='None', rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[23:38:46.730] iteration 100 [86.70 sec]: learning rate : 0.000100 loss : 0.664336 +[23:40:12.538] iteration 200 [172.51 sec]: learning rate : 0.000100 loss : 0.673694 +[23:41:38.503] iteration 300 [258.47 sec]: learning rate : 0.000100 loss : 0.677507 +[23:43:04.549] iteration 400 [344.52 sec]: learning rate : 0.000100 loss : 0.600353 +[23:44:30.579] iteration 500 [430.55 sec]: learning rate : 0.000100 loss : 0.942315 +[23:45:56.772] iteration 600 [516.74 sec]: learning rate : 0.000100 loss : 0.754277 +[23:47:22.996] iteration 700 [602.96 sec]: learning rate : 0.000100 loss : 0.841191 +[23:48:49.170] iteration 800 [689.14 sec]: learning rate : 0.000100 loss : 1.091522 +[23:50:15.462] iteration 900 [775.43 sec]: learning rate : 0.000100 loss : 0.601826 +[23:51:41.741] iteration 1000 [861.71 sec]: learning rate : 0.000100 loss : 0.449127 +[23:53:08.008] iteration 1100 [947.98 sec]: learning rate : 0.000100 loss : 1.238433 +[23:54:34.263] iteration 1200 [1034.23 sec]: learning rate : 0.000100 loss : 0.651324 +[23:56:00.607] iteration 1300 [1120.57 sec]: learning rate : 0.000100 loss : 0.669230 +[23:57:26.896] iteration 1400 [1206.86 sec]: learning rate : 0.000100 loss : 1.253515 +[23:58:53.204] iteration 1500 [1293.17 sec]: learning rate : 0.000100 loss : 1.061922 +[00:00:19.454] iteration 1600 [1379.42 sec]: learning rate : 0.000100 loss : 0.563248 +[00:01:45.802] iteration 1700 [1465.77 sec]: learning rate : 0.000100 loss : 0.621403 +[00:03:12.172] iteration 1800 [1552.14 sec]: learning rate : 0.000100 loss : 0.999637 +[00:04:38.472] iteration 1900 [1638.44 sec]: learning rate : 0.000100 loss : 0.766830 +[00:06:04.846] iteration 2000 [1724.82 sec]: learning rate : 0.000100 loss : 0.869758 +[00:07:16.459] Epoch 0 Evaluation: +[00:08:07.700] average MSE: 0.052322760224342346 average PSNR: 28.13818310345704 average SSIM: 0.6836697626540628 +[00:08:22.699] iteration 2100 [14.94 sec]: learning rate : 0.000100 loss : 0.527529 +[00:09:48.959] 
iteration 2200 [101.20 sec]: learning rate : 0.000100 loss : 0.471525 +[00:11:15.314] iteration 2300 [187.55 sec]: learning rate : 0.000100 loss : 0.781051 +[00:12:41.679] iteration 2400 [273.92 sec]: learning rate : 0.000100 loss : 0.693218 +[00:14:08.076] iteration 2500 [360.31 sec]: learning rate : 0.000100 loss : 0.395070 +[00:15:34.544] iteration 2600 [446.78 sec]: learning rate : 0.000100 loss : 0.530971 +[00:17:00.926] iteration 2700 [533.16 sec]: learning rate : 0.000100 loss : 0.525543 +[00:18:27.378] iteration 2800 [619.62 sec]: learning rate : 0.000100 loss : 0.836508 +[00:19:53.836] iteration 2900 [706.07 sec]: learning rate : 0.000100 loss : 0.570257 +[00:21:20.253] iteration 3000 [792.49 sec]: learning rate : 0.000100 loss : 0.661190 +[00:22:46.738] iteration 3100 [878.98 sec]: learning rate : 0.000100 loss : 0.669504 +[00:24:13.210] iteration 3200 [965.45 sec]: learning rate : 0.000100 loss : 0.598372 +[00:25:39.604] iteration 3300 [1051.84 sec]: learning rate : 0.000100 loss : 0.414270 +[00:27:06.042] iteration 3400 [1138.28 sec]: learning rate : 0.000100 loss : 0.664440 +[00:28:32.490] iteration 3500 [1224.73 sec]: learning rate : 0.000100 loss : 0.497923 +[00:29:58.879] iteration 3600 [1311.12 sec]: learning rate : 0.000100 loss : 0.542773 +[00:31:25.356] iteration 3700 [1397.59 sec]: learning rate : 0.000100 loss : 0.521152 +[00:32:51.821] iteration 3800 [1484.06 sec]: learning rate : 0.000100 loss : 0.503739 +[00:34:18.206] iteration 3900 [1570.45 sec]: learning rate : 0.000100 loss : 0.648733 +[00:35:44.621] iteration 4000 [1656.86 sec]: learning rate : 0.000100 loss : 0.680642 +[00:37:10.996] iteration 4100 [1743.23 sec]: learning rate : 0.000100 loss : 0.564777 +[00:38:08.069] Epoch 1 Evaluation: +[00:38:59.321] average MSE: 0.052759021520614624 average PSNR: 28.168008618206713 average SSIM: 0.7015316835982738 +[00:39:28.906] iteration 4200 [29.52 sec]: learning rate : 0.000100 loss : 0.616342 +[00:40:55.333] iteration 4300 [115.95 sec]: learning rate : 0.000100 loss : 0.504045 +[00:42:21.738] iteration 4400 [202.35 sec]: learning rate : 0.000100 loss : 0.422425 +[00:43:48.038] iteration 4500 [288.65 sec]: learning rate : 0.000100 loss : 0.662416 +[00:45:14.404] iteration 4600 [375.02 sec]: learning rate : 0.000100 loss : 0.596451 +[00:46:40.701] iteration 4700 [461.32 sec]: learning rate : 0.000100 loss : 0.666779 +[00:48:07.067] iteration 4800 [547.68 sec]: learning rate : 0.000100 loss : 0.414417 +[00:49:33.417] iteration 4900 [634.03 sec]: learning rate : 0.000100 loss : 0.749834 +[00:50:59.722] iteration 5000 [720.34 sec]: learning rate : 0.000100 loss : 0.716690 +[00:52:26.076] iteration 5100 [806.69 sec]: learning rate : 0.000100 loss : 0.708606 +[00:53:52.417] iteration 5200 [893.03 sec]: learning rate : 0.000100 loss : 0.702893 +[00:55:18.702] iteration 5300 [979.32 sec]: learning rate : 0.000100 loss : 0.480220 +[00:56:45.039] iteration 5400 [1065.66 sec]: learning rate : 0.000100 loss : 0.789153 +[00:58:11.324] iteration 5500 [1151.94 sec]: learning rate : 0.000100 loss : 0.531845 +[00:59:37.575] iteration 5600 [1238.19 sec]: learning rate : 0.000100 loss : 0.620262 +[01:01:03.893] iteration 5700 [1324.51 sec]: learning rate : 0.000100 loss : 0.371379 +[01:02:30.228] iteration 5800 [1410.84 sec]: learning rate : 0.000100 loss : 0.433268 +[01:03:56.513] iteration 5900 [1497.13 sec]: learning rate : 0.000100 loss : 0.611900 +[01:05:22.835] iteration 6000 [1583.45 sec]: learning rate : 0.000100 loss : 0.741463 +[01:06:49.130] iteration 6100 [1669.75 sec]: 
learning rate : 0.000100 loss : 0.848679 +[01:08:15.367] iteration 6200 [1755.98 sec]: learning rate : 0.000100 loss : 0.608874 +[01:08:57.595] Epoch 2 Evaluation: +[01:09:47.427] average MSE: 0.04796710982918739 average PSNR: 28.60557655152969 average SSIM: 0.7062884497006925 +[01:10:31.766] iteration 6300 [44.28 sec]: learning rate : 0.000100 loss : 0.553436 +[01:11:57.955] iteration 6400 [130.47 sec]: learning rate : 0.000100 loss : 0.505112 +[01:13:24.203] iteration 6500 [216.71 sec]: learning rate : 0.000100 loss : 0.581555 +[01:14:50.447] iteration 6600 [302.96 sec]: learning rate : 0.000100 loss : 0.525442 +[01:16:16.638] iteration 6700 [389.15 sec]: learning rate : 0.000100 loss : 0.668216 +[01:17:42.892] iteration 6800 [475.40 sec]: learning rate : 0.000100 loss : 0.535616 +[01:19:09.088] iteration 6900 [561.60 sec]: learning rate : 0.000100 loss : 0.675732 +[01:20:35.366] iteration 7000 [647.88 sec]: learning rate : 0.000100 loss : 0.501462 +[01:22:01.654] iteration 7100 [734.16 sec]: learning rate : 0.000100 loss : 0.561687 +[01:23:27.857] iteration 7200 [820.37 sec]: learning rate : 0.000100 loss : 0.460082 +[01:24:54.142] iteration 7300 [906.65 sec]: learning rate : 0.000100 loss : 0.351227 +[01:26:20.405] iteration 7400 [992.91 sec]: learning rate : 0.000100 loss : 0.370678 +[01:27:46.602] iteration 7500 [1079.11 sec]: learning rate : 0.000100 loss : 0.651751 +[01:29:12.856] iteration 7600 [1165.37 sec]: learning rate : 0.000100 loss : 0.298766 +[01:30:39.044] iteration 7700 [1251.55 sec]: learning rate : 0.000100 loss : 0.552115 +[01:32:05.251] iteration 7800 [1337.76 sec]: learning rate : 0.000100 loss : 0.579343 +[01:33:31.485] iteration 7900 [1424.00 sec]: learning rate : 0.000100 loss : 0.679708 +[01:34:57.684] iteration 8000 [1510.19 sec]: learning rate : 0.000100 loss : 0.394292 +[01:36:23.913] iteration 8100 [1596.42 sec]: learning rate : 0.000100 loss : 0.325633 +[01:37:50.150] iteration 8200 [1682.66 sec]: learning rate : 0.000100 loss : 0.492707 +[01:39:16.320] iteration 8300 [1768.83 sec]: learning rate : 0.000100 loss : 0.294093 +[01:39:43.860] Epoch 3 Evaluation: +[01:40:33.161] average MSE: 0.045949067920446396 average PSNR: 28.839230143329008 average SSIM: 0.7118368691890277 +[01:41:32.102] iteration 8400 [58.88 sec]: learning rate : 0.000100 loss : 0.721214 +[01:42:58.199] iteration 8500 [144.97 sec]: learning rate : 0.000100 loss : 0.572434 +[01:44:24.417] iteration 8600 [231.19 sec]: learning rate : 0.000100 loss : 0.455191 +[01:45:50.571] iteration 8700 [317.35 sec]: learning rate : 0.000100 loss : 0.965830 +[01:47:16.694] iteration 8800 [403.47 sec]: learning rate : 0.000100 loss : 1.131175 +[01:48:42.924] iteration 8900 [489.70 sec]: learning rate : 0.000100 loss : 0.566403 +[01:50:09.128] iteration 9000 [575.90 sec]: learning rate : 0.000100 loss : 0.620789 +[01:51:35.278] iteration 9100 [662.05 sec]: learning rate : 0.000100 loss : 0.523786 +[01:53:01.466] iteration 9200 [748.24 sec]: learning rate : 0.000100 loss : 0.335563 +[01:54:27.701] iteration 9300 [834.48 sec]: learning rate : 0.000100 loss : 0.743833 +[01:55:53.879] iteration 9400 [920.65 sec]: learning rate : 0.000100 loss : 0.846852 +[01:57:20.117] iteration 9500 [1006.89 sec]: learning rate : 0.000100 loss : 0.465688 +[01:58:46.275] iteration 9600 [1093.05 sec]: learning rate : 0.000100 loss : 0.612938 +[02:00:12.493] iteration 9700 [1179.27 sec]: learning rate : 0.000100 loss : 0.537913 +[02:01:38.703] iteration 9800 [1265.48 sec]: learning rate : 0.000100 loss : 0.435062 +[02:03:04.874] 
iteration 9900 [1351.65 sec]: learning rate : 0.000100 loss : 0.306808 +[02:04:31.128] iteration 10000 [1437.90 sec]: learning rate : 0.000100 loss : 0.622640 +[02:05:57.344] iteration 10100 [1524.12 sec]: learning rate : 0.000100 loss : 0.502954 +[02:07:23.503] iteration 10200 [1610.28 sec]: learning rate : 0.000100 loss : 0.629411 +[02:08:49.703] iteration 10300 [1696.48 sec]: learning rate : 0.000100 loss : 0.538307 +[02:10:15.944] iteration 10400 [1782.72 sec]: learning rate : 0.000100 loss : 0.415289 +[02:10:28.841] Epoch 4 Evaluation: +[02:11:18.547] average MSE: 0.04444773122668266 average PSNR: 29.022143608163635 average SSIM: 0.7145598597284041 +[02:12:31.960] iteration 10500 [73.35 sec]: learning rate : 0.000100 loss : 0.420704 +[02:13:58.141] iteration 10600 [159.53 sec]: learning rate : 0.000100 loss : 0.530065 +[02:15:24.284] iteration 10700 [245.67 sec]: learning rate : 0.000100 loss : 0.483061 +[02:16:50.375] iteration 10800 [331.76 sec]: learning rate : 0.000100 loss : 0.420063 +[02:18:16.547] iteration 10900 [417.94 sec]: learning rate : 0.000100 loss : 0.492313 +[02:19:42.730] iteration 11000 [504.12 sec]: learning rate : 0.000100 loss : 0.497682 +[02:21:08.829] iteration 11100 [590.22 sec]: learning rate : 0.000100 loss : 0.342341 +[02:22:35.003] iteration 11200 [676.39 sec]: learning rate : 0.000100 loss : 0.590305 +[02:24:01.182] iteration 11300 [762.57 sec]: learning rate : 0.000100 loss : 0.367453 +[02:25:27.285] iteration 11400 [848.67 sec]: learning rate : 0.000100 loss : 0.622613 +[02:26:53.462] iteration 11500 [934.85 sec]: learning rate : 0.000100 loss : 0.530422 +[02:28:19.631] iteration 11600 [1021.02 sec]: learning rate : 0.000100 loss : 0.561444 +[02:29:45.763] iteration 11700 [1107.15 sec]: learning rate : 0.000100 loss : 0.507816 +[02:31:11.970] iteration 11800 [1193.36 sec]: learning rate : 0.000100 loss : 0.912397 +[02:32:38.190] iteration 11900 [1279.58 sec]: learning rate : 0.000100 loss : 0.779102 +[02:34:04.330] iteration 12000 [1365.72 sec]: learning rate : 0.000100 loss : 0.306851 +[02:35:30.554] iteration 12100 [1451.94 sec]: learning rate : 0.000100 loss : 0.489913 +[02:36:56.695] iteration 12200 [1538.09 sec]: learning rate : 0.000100 loss : 0.590093 +[02:38:22.838] iteration 12300 [1624.23 sec]: learning rate : 0.000100 loss : 0.595138 +[02:39:49.069] iteration 12400 [1710.46 sec]: learning rate : 0.000100 loss : 0.676882 +[02:41:13.468] Epoch 5 Evaluation: +[02:42:03.916] average MSE: 0.04352114722132683 average PSNR: 29.12599841711727 average SSIM: 0.7168938053194104 +[02:42:05.885] iteration 12500 [1.91 sec]: learning rate : 0.000100 loss : 0.408864 +[02:43:32.194] iteration 12600 [88.22 sec]: learning rate : 0.000100 loss : 0.353820 +[02:44:58.478] iteration 12700 [174.50 sec]: learning rate : 0.000100 loss : 0.545949 +[02:46:24.718] iteration 12800 [260.74 sec]: learning rate : 0.000100 loss : 0.464313 +[02:47:51.019] iteration 12900 [347.04 sec]: learning rate : 0.000100 loss : 0.550747 +[02:49:17.328] iteration 13000 [433.35 sec]: learning rate : 0.000100 loss : 0.638906 +[02:50:43.577] iteration 13100 [519.60 sec]: learning rate : 0.000100 loss : 0.451821 +[02:52:09.883] iteration 13200 [605.91 sec]: learning rate : 0.000100 loss : 0.521371 +[02:53:36.148] iteration 13300 [692.17 sec]: learning rate : 0.000100 loss : 0.465796 +[02:55:02.471] iteration 13400 [778.49 sec]: learning rate : 0.000100 loss : 0.518116 +[02:56:28.823] iteration 13500 [864.85 sec]: learning rate : 0.000100 loss : 0.429209 +[02:57:55.106] iteration 13600 [951.13 
sec]: learning rate : 0.000100 loss : 0.668473 +[02:59:21.429] iteration 13700 [1037.45 sec]: learning rate : 0.000100 loss : 0.378101 +[03:00:47.766] iteration 13800 [1123.79 sec]: learning rate : 0.000100 loss : 0.464662 +[03:02:14.071] iteration 13900 [1210.09 sec]: learning rate : 0.000100 loss : 0.493137 +[03:03:40.412] iteration 14000 [1296.44 sec]: learning rate : 0.000100 loss : 0.689547 +[03:05:06.798] iteration 14100 [1382.82 sec]: learning rate : 0.000100 loss : 0.284103 +[03:06:33.123] iteration 14200 [1469.15 sec]: learning rate : 0.000100 loss : 0.359892 +[03:07:59.496] iteration 14300 [1555.52 sec]: learning rate : 0.000100 loss : 0.311354 +[03:09:25.819] iteration 14400 [1641.84 sec]: learning rate : 0.000100 loss : 0.449686 +[03:10:52.205] iteration 14500 [1728.23 sec]: learning rate : 0.000100 loss : 0.432108 +[03:12:02.144] Epoch 6 Evaluation: +[03:12:51.814] average MSE: 0.04328539967536926 average PSNR: 29.173886717827212 average SSIM: 0.719590656758911 +[03:13:08.468] iteration 14600 [16.59 sec]: learning rate : 0.000100 loss : 0.578056 +[03:14:34.737] iteration 14700 [102.86 sec]: learning rate : 0.000100 loss : 0.569212 +[03:16:01.102] iteration 14800 [189.23 sec]: learning rate : 0.000100 loss : 0.701353 +[03:17:27.458] iteration 14900 [275.58 sec]: learning rate : 0.000100 loss : 0.324140 +[03:18:53.754] iteration 15000 [361.88 sec]: learning rate : 0.000100 loss : 0.295046 +[03:20:20.143] iteration 15100 [448.27 sec]: learning rate : 0.000100 loss : 0.287883 +[03:21:46.518] iteration 15200 [534.64 sec]: learning rate : 0.000100 loss : 0.366019 +[03:23:12.848] iteration 15300 [620.97 sec]: learning rate : 0.000100 loss : 0.555081 +[03:24:39.210] iteration 15400 [707.33 sec]: learning rate : 0.000100 loss : 0.620935 +[03:26:05.619] iteration 15500 [793.74 sec]: learning rate : 0.000100 loss : 0.335966 +[03:27:31.953] iteration 15600 [880.08 sec]: learning rate : 0.000100 loss : 0.353579 +[03:28:58.330] iteration 15700 [966.45 sec]: learning rate : 0.000100 loss : 0.894829 +[03:30:24.723] iteration 15800 [1052.85 sec]: learning rate : 0.000100 loss : 0.294456 +[03:31:51.047] iteration 15900 [1139.17 sec]: learning rate : 0.000100 loss : 0.467329 +[03:33:17.427] iteration 16000 [1225.55 sec]: learning rate : 0.000100 loss : 0.687781 +[03:34:43.804] iteration 16100 [1311.93 sec]: learning rate : 0.000100 loss : 0.749439 +[03:36:10.123] iteration 16200 [1398.25 sec]: learning rate : 0.000100 loss : 0.376435 +[03:37:36.490] iteration 16300 [1484.61 sec]: learning rate : 0.000100 loss : 0.483239 +[03:39:02.811] iteration 16400 [1570.93 sec]: learning rate : 0.000100 loss : 0.383261 +[03:40:29.204] iteration 16500 [1657.33 sec]: learning rate : 0.000100 loss : 0.691282 +[03:41:55.602] iteration 16600 [1743.72 sec]: learning rate : 0.000100 loss : 0.458725 +[03:42:50.833] Epoch 7 Evaluation: +[03:43:40.621] average MSE: 0.043066829442977905 average PSNR: 29.21811297370013 average SSIM: 0.7203694299638541 +[03:44:11.926] iteration 16700 [31.24 sec]: learning rate : 0.000100 loss : 0.398560 +[03:45:38.346] iteration 16800 [117.66 sec]: learning rate : 0.000100 loss : 0.243379 +[03:47:04.676] iteration 16900 [203.99 sec]: learning rate : 0.000100 loss : 0.272782 +[03:48:30.988] iteration 17000 [290.30 sec]: learning rate : 0.000100 loss : 0.950679 +[03:49:57.370] iteration 17100 [376.69 sec]: learning rate : 0.000100 loss : 0.542483 +[03:51:23.718] iteration 17200 [463.03 sec]: learning rate : 0.000100 loss : 0.219986 +[03:52:50.088] iteration 17300 [549.41 sec]: learning 
rate : 0.000100 loss : 0.526054 +[03:54:16.538] iteration 17400 [635.86 sec]: learning rate : 0.000100 loss : 0.529036 +[03:55:42.831] iteration 17500 [722.15 sec]: learning rate : 0.000100 loss : 0.503499 +[03:57:09.208] iteration 17600 [808.52 sec]: learning rate : 0.000100 loss : 0.459672 +[03:58:35.541] iteration 17700 [894.86 sec]: learning rate : 0.000100 loss : 0.520226 +[04:00:01.869] iteration 17800 [981.18 sec]: learning rate : 0.000100 loss : 0.415475 +[04:01:28.227] iteration 17900 [1067.54 sec]: learning rate : 0.000100 loss : 0.679521 +[04:02:54.530] iteration 18000 [1153.85 sec]: learning rate : 0.000100 loss : 0.724895 +[04:04:20.915] iteration 18100 [1240.23 sec]: learning rate : 0.000100 loss : 0.611724 +[04:05:47.321] iteration 18200 [1326.64 sec]: learning rate : 0.000100 loss : 0.374283 +[04:07:13.644] iteration 18300 [1412.96 sec]: learning rate : 0.000100 loss : 0.403620 +[04:08:40.037] iteration 18400 [1499.35 sec]: learning rate : 0.000100 loss : 0.481739 +[04:10:06.348] iteration 18500 [1585.67 sec]: learning rate : 0.000100 loss : 0.594136 +[04:11:32.728] iteration 18600 [1672.05 sec]: learning rate : 0.000100 loss : 0.564335 +[04:12:59.125] iteration 18700 [1758.44 sec]: learning rate : 0.000100 loss : 0.771390 +[04:13:39.694] Epoch 8 Evaluation: +[04:14:32.135] average MSE: 0.04222748056054115 average PSNR: 29.321932821378272 average SSIM: 0.7234668554320511 +[04:15:18.062] iteration 18800 [45.87 sec]: learning rate : 0.000100 loss : 0.559783 +[04:16:44.417] iteration 18900 [132.22 sec]: learning rate : 0.000100 loss : 0.479078 +[04:18:10.730] iteration 19000 [218.53 sec]: learning rate : 0.000100 loss : 0.744898 +[04:19:36.975] iteration 19100 [304.78 sec]: learning rate : 0.000100 loss : 0.649605 +[04:21:03.273] iteration 19200 [391.08 sec]: learning rate : 0.000100 loss : 0.498046 +[04:22:29.597] iteration 19300 [477.40 sec]: learning rate : 0.000100 loss : 0.491904 +[04:23:55.874] iteration 19400 [563.68 sec]: learning rate : 0.000100 loss : 0.380394 +[04:25:22.175] iteration 19500 [649.98 sec]: learning rate : 0.000100 loss : 0.404777 +[04:26:48.499] iteration 19600 [736.30 sec]: learning rate : 0.000100 loss : 0.482228 +[04:28:14.741] iteration 19700 [822.54 sec]: learning rate : 0.000100 loss : 0.632927 +[04:29:41.084] iteration 19800 [908.89 sec]: learning rate : 0.000100 loss : 0.710430 +[04:31:07.374] iteration 19900 [995.18 sec]: learning rate : 0.000100 loss : 0.443171 +[04:32:33.627] iteration 20000 [1081.43 sec]: learning rate : 0.000025 loss : 0.598682 +[04:32:33.784] save model to model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/iter_20000.pth +[04:34:00.074] iteration 20100 [1167.88 sec]: learning rate : 0.000050 loss : 0.256254 +[04:35:26.343] iteration 20200 [1254.15 sec]: learning rate : 0.000050 loss : 0.332421 +[04:36:52.697] iteration 20300 [1340.50 sec]: learning rate : 0.000050 loss : 0.837223 +[04:38:19.055] iteration 20400 [1426.86 sec]: learning rate : 0.000050 loss : 0.736112 +[04:39:45.323] iteration 20500 [1513.13 sec]: learning rate : 0.000050 loss : 0.692908 +[04:41:11.631] iteration 20600 [1599.43 sec]: learning rate : 0.000050 loss : 0.383900 +[04:42:37.943] iteration 20700 [1685.74 sec]: learning rate : 0.000050 loss : 0.489619 +[04:44:04.200] iteration 20800 [1772.00 sec]: learning rate : 0.000050 loss : 0.353717 +[04:44:30.054] Epoch 9 Evaluation: +[04:45:21.678] average MSE: 0.04242523014545441 average PSNR: 29.2983404546992 average SSIM: 0.7228575728007648 +[04:46:22.396] iteration 20900 [60.66 sec]: learning rate : 
0.000050 loss : 0.240118 +[04:47:48.636] iteration 21000 [146.90 sec]: learning rate : 0.000050 loss : 0.279137 +[04:49:14.924] iteration 21100 [233.18 sec]: learning rate : 0.000050 loss : 0.296159 +[04:50:41.196] iteration 21200 [319.46 sec]: learning rate : 0.000050 loss : 0.445494 +[04:52:07.450] iteration 21300 [405.71 sec]: learning rate : 0.000050 loss : 0.537314 +[04:53:33.760] iteration 21400 [492.02 sec]: learning rate : 0.000050 loss : 0.349848 +[04:55:00.062] iteration 21500 [578.38 sec]: learning rate : 0.000050 loss : 0.515165 +[04:56:26.296] iteration 21600 [664.56 sec]: learning rate : 0.000050 loss : 0.324323 +[04:57:52.598] iteration 21700 [750.86 sec]: learning rate : 0.000050 loss : 0.603855 +[04:59:18.836] iteration 21800 [837.09 sec]: learning rate : 0.000050 loss : 0.639075 +[05:00:45.151] iteration 21900 [923.41 sec]: learning rate : 0.000050 loss : 0.471756 +[05:02:11.441] iteration 22000 [1009.70 sec]: learning rate : 0.000050 loss : 0.512122 +[05:03:37.712] iteration 22100 [1095.97 sec]: learning rate : 0.000050 loss : 0.755381 +[05:05:03.970] iteration 22200 [1182.23 sec]: learning rate : 0.000050 loss : 0.445402 +[05:06:30.211] iteration 22300 [1268.47 sec]: learning rate : 0.000050 loss : 0.774442 +[05:07:56.487] iteration 22400 [1354.75 sec]: learning rate : 0.000050 loss : 0.402385 +[05:09:22.837] iteration 22500 [1441.10 sec]: learning rate : 0.000050 loss : 0.697577 +[05:10:49.091] iteration 22600 [1527.35 sec]: learning rate : 0.000050 loss : 0.581265 +[05:12:15.381] iteration 22700 [1613.64 sec]: learning rate : 0.000050 loss : 0.381248 +[05:13:41.716] iteration 22800 [1699.98 sec]: learning rate : 0.000050 loss : 0.464800 +[05:15:07.959] iteration 22900 [1786.22 sec]: learning rate : 0.000050 loss : 0.354382 +[05:15:19.139] Epoch 10 Evaluation: +[05:16:09.691] average MSE: 0.0422147773206234 average PSNR: 29.346057605191447 average SSIM: 0.7251580786554527 +[05:17:25.108] iteration 23000 [75.35 sec]: learning rate : 0.000050 loss : 0.249136 +[05:18:51.456] iteration 23100 [161.70 sec]: learning rate : 0.000050 loss : 0.488902 +[05:20:17.733] iteration 23200 [247.98 sec]: learning rate : 0.000050 loss : 0.267896 +[05:21:44.071] iteration 23300 [334.32 sec]: learning rate : 0.000050 loss : 0.476700 +[05:23:10.358] iteration 23400 [420.60 sec]: learning rate : 0.000050 loss : 0.620649 +[05:24:36.732] iteration 23500 [506.98 sec]: learning rate : 0.000050 loss : 0.554005 +[05:26:03.075] iteration 23600 [593.32 sec]: learning rate : 0.000050 loss : 0.413105 +[05:27:29.361] iteration 23700 [679.61 sec]: learning rate : 0.000050 loss : 0.367886 +[05:28:55.693] iteration 23800 [765.94 sec]: learning rate : 0.000050 loss : 0.294633 +[05:30:22.039] iteration 23900 [852.29 sec]: learning rate : 0.000050 loss : 0.499357 +[05:31:48.386] iteration 24000 [938.63 sec]: learning rate : 0.000050 loss : 0.330056 +[05:33:14.774] iteration 24100 [1025.02 sec]: learning rate : 0.000050 loss : 0.398210 +[05:34:41.161] iteration 24200 [1111.41 sec]: learning rate : 0.000050 loss : 0.348534 +[05:36:07.463] iteration 24300 [1197.71 sec]: learning rate : 0.000050 loss : 0.491726 +[05:37:33.818] iteration 24400 [1284.06 sec]: learning rate : 0.000050 loss : 0.468959 +[05:39:00.121] iteration 24500 [1370.37 sec]: learning rate : 0.000050 loss : 0.508388 +[05:40:26.492] iteration 24600 [1456.74 sec]: learning rate : 0.000050 loss : 0.391013 +[05:41:52.880] iteration 24700 [1543.13 sec]: learning rate : 0.000050 loss : 0.605703 +[05:43:19.203] iteration 24800 [1629.45 sec]: learning 
rate : 0.000050 loss : 0.414137 +[05:44:45.549] iteration 24900 [1715.79 sec]: learning rate : 0.000050 loss : 0.947000 +[05:46:08.371] Epoch 11 Evaluation: +[05:46:59.824] average MSE: 0.041780710220336914 average PSNR: 29.38239141195723 average SSIM: 0.7247598633516338 +[05:47:03.523] iteration 25000 [3.64 sec]: learning rate : 0.000050 loss : 0.452893 +[05:48:29.798] iteration 25100 [89.91 sec]: learning rate : 0.000050 loss : 0.461613 +[05:49:56.192] iteration 25200 [176.31 sec]: learning rate : 0.000050 loss : 0.362545 +[05:51:22.482] iteration 25300 [262.59 sec]: learning rate : 0.000050 loss : 1.043213 +[05:52:48.802] iteration 25400 [348.92 sec]: learning rate : 0.000050 loss : 0.387752 +[05:54:15.158] iteration 25500 [435.27 sec]: learning rate : 0.000050 loss : 0.502433 +[05:55:41.466] iteration 25600 [521.58 sec]: learning rate : 0.000050 loss : 0.333355 +[05:57:07.808] iteration 25700 [607.92 sec]: learning rate : 0.000050 loss : 0.504113 +[05:58:34.172] iteration 25800 [694.29 sec]: learning rate : 0.000050 loss : 0.525252 +[06:00:00.488] iteration 25900 [780.60 sec]: learning rate : 0.000050 loss : 0.387437 +[06:01:26.844] iteration 26000 [866.96 sec]: learning rate : 0.000050 loss : 0.444659 +[06:02:53.206] iteration 26100 [953.32 sec]: learning rate : 0.000050 loss : 0.583337 +[06:04:19.499] iteration 26200 [1039.61 sec]: learning rate : 0.000050 loss : 0.539472 +[06:05:45.902] iteration 26300 [1126.02 sec]: learning rate : 0.000050 loss : 0.391912 +[06:07:12.276] iteration 26400 [1212.39 sec]: learning rate : 0.000050 loss : 0.500449 +[06:08:38.588] iteration 26500 [1298.70 sec]: learning rate : 0.000050 loss : 0.316284 +[06:10:04.954] iteration 26600 [1385.07 sec]: learning rate : 0.000050 loss : 0.370891 +[06:11:31.331] iteration 26700 [1471.44 sec]: learning rate : 0.000050 loss : 0.363429 +[06:12:57.626] iteration 26800 [1557.74 sec]: learning rate : 0.000050 loss : 0.332149 +[06:14:24.017] iteration 26900 [1644.15 sec]: learning rate : 0.000050 loss : 0.337152 +[06:15:50.341] iteration 27000 [1730.45 sec]: learning rate : 0.000050 loss : 0.732571 +[06:16:58.514] Epoch 12 Evaluation: +[06:17:48.220] average MSE: 0.04103581979870796 average PSNR: 29.477005382571495 average SSIM: 0.7262191319676624 +[06:18:06.575] iteration 27100 [18.29 sec]: learning rate : 0.000050 loss : 0.477075 +[06:19:32.981] iteration 27200 [104.70 sec]: learning rate : 0.000050 loss : 0.524073 +[06:20:59.351] iteration 27300 [191.13 sec]: learning rate : 0.000050 loss : 0.467193 +[06:22:25.669] iteration 27400 [277.39 sec]: learning rate : 0.000050 loss : 0.320577 +[06:23:52.017] iteration 27500 [363.73 sec]: learning rate : 0.000050 loss : 0.440728 +[06:25:18.354] iteration 27600 [450.07 sec]: learning rate : 0.000050 loss : 0.535586 +[06:26:44.740] iteration 27700 [536.46 sec]: learning rate : 0.000050 loss : 0.683070 +[06:28:11.146] iteration 27800 [622.86 sec]: learning rate : 0.000050 loss : 0.563552 +[06:29:37.481] iteration 27900 [709.20 sec]: learning rate : 0.000050 loss : 0.477667 +[06:31:03.900] iteration 28000 [795.62 sec]: learning rate : 0.000050 loss : 0.376754 +[06:32:30.334] iteration 28100 [882.05 sec]: learning rate : 0.000050 loss : 0.654584 +[06:33:56.694] iteration 28200 [968.41 sec]: learning rate : 0.000050 loss : 0.417030 +[06:35:23.149] iteration 28300 [1054.87 sec]: learning rate : 0.000050 loss : 0.606362 +[06:36:49.564] iteration 28400 [1141.28 sec]: learning rate : 0.000050 loss : 0.466398 +[06:38:15.896] iteration 28500 [1227.61 sec]: learning rate : 0.000050 loss : 
0.393928 +[06:39:42.300] iteration 28600 [1314.02 sec]: learning rate : 0.000050 loss : 0.277147 +[06:41:08.653] iteration 28700 [1400.37 sec]: learning rate : 0.000050 loss : 0.610407 +[06:42:34.981] iteration 28800 [1486.70 sec]: learning rate : 0.000050 loss : 0.286592 +[06:44:01.379] iteration 28900 [1573.10 sec]: learning rate : 0.000050 loss : 0.499104 +[06:45:27.721] iteration 29000 [1659.44 sec]: learning rate : 0.000050 loss : 0.566345 +[06:46:54.092] iteration 29100 [1745.81 sec]: learning rate : 0.000050 loss : 0.416112 +[06:47:47.603] Epoch 13 Evaluation: +[06:48:38.587] average MSE: 0.04174678400158882 average PSNR: 29.42609518020312 average SSIM: 0.7273784125196615 +[06:49:11.717] iteration 29200 [33.07 sec]: learning rate : 0.000050 loss : 0.355736 +[06:50:37.974] iteration 29300 [119.33 sec]: learning rate : 0.000050 loss : 0.523761 +[06:52:04.337] iteration 29400 [205.69 sec]: learning rate : 0.000050 loss : 0.427152 +[06:53:30.700] iteration 29500 [292.05 sec]: learning rate : 0.000050 loss : 0.440177 +[06:54:57.024] iteration 29600 [378.37 sec]: learning rate : 0.000050 loss : 0.331404 +[06:56:23.373] iteration 29700 [464.72 sec]: learning rate : 0.000050 loss : 0.419263 +[06:57:49.753] iteration 29800 [551.10 sec]: learning rate : 0.000050 loss : 0.685297 +[06:59:16.064] iteration 29900 [637.41 sec]: learning rate : 0.000050 loss : 0.485479 +[07:00:42.412] iteration 30000 [723.76 sec]: learning rate : 0.000050 loss : 0.365982 +[07:02:08.792] iteration 30100 [810.14 sec]: learning rate : 0.000050 loss : 0.634789 +[07:03:35.114] iteration 30200 [896.46 sec]: learning rate : 0.000050 loss : 0.760280 +[07:05:01.479] iteration 30300 [982.83 sec]: learning rate : 0.000050 loss : 0.510593 +[07:06:27.799] iteration 30400 [1069.15 sec]: learning rate : 0.000050 loss : 0.557650 +[07:07:54.181] iteration 30500 [1155.53 sec]: learning rate : 0.000050 loss : 0.289315 +[07:09:20.535] iteration 30600 [1241.91 sec]: learning rate : 0.000050 loss : 0.525229 +[07:10:46.868] iteration 30700 [1328.22 sec]: learning rate : 0.000050 loss : 0.461378 +[07:12:13.248] iteration 30800 [1414.60 sec]: learning rate : 0.000050 loss : 0.335060 +[07:13:39.626] iteration 30900 [1500.98 sec]: learning rate : 0.000050 loss : 0.400410 +[07:15:05.939] iteration 31000 [1587.29 sec]: learning rate : 0.000050 loss : 0.626410 +[07:16:32.309] iteration 31100 [1673.66 sec]: learning rate : 0.000050 loss : 0.618775 +[07:17:58.680] iteration 31200 [1760.03 sec]: learning rate : 0.000050 loss : 0.837674 +[07:18:37.502] Epoch 14 Evaluation: +[07:19:27.885] average MSE: 0.041153065860271454 average PSNR: 29.47814930784136 average SSIM: 0.7275911043829485 +[07:20:15.584] iteration 31300 [47.64 sec]: learning rate : 0.000050 loss : 0.399871 +[07:21:42.013] iteration 31400 [134.06 sec]: learning rate : 0.000050 loss : 0.379966 +[07:23:08.380] iteration 31500 [220.43 sec]: learning rate : 0.000050 loss : 0.609327 +[07:24:34.688] iteration 31600 [306.74 sec]: learning rate : 0.000050 loss : 0.392611 +[07:26:01.058] iteration 31700 [393.11 sec]: learning rate : 0.000050 loss : 0.495254 +[07:27:27.398] iteration 31800 [479.45 sec]: learning rate : 0.000050 loss : 0.457150 +[07:28:53.805] iteration 31900 [565.86 sec]: learning rate : 0.000050 loss : 0.521354 +[07:30:20.179] iteration 32000 [652.23 sec]: learning rate : 0.000050 loss : 0.417372 +[07:31:46.493] iteration 32100 [738.55 sec]: learning rate : 0.000050 loss : 0.509083 +[07:33:12.899] iteration 32200 [824.95 sec]: learning rate : 0.000050 loss : 0.400464 
+[07:34:39.234] iteration 32300 [911.29 sec]: learning rate : 0.000050 loss : 0.593222 +[07:36:05.618] iteration 32400 [997.67 sec]: learning rate : 0.000050 loss : 0.486109 +[07:37:32.026] iteration 32500 [1084.08 sec]: learning rate : 0.000050 loss : 0.525904 +[07:38:58.385] iteration 32600 [1170.44 sec]: learning rate : 0.000050 loss : 0.421076 +[07:40:24.780] iteration 32700 [1256.83 sec]: learning rate : 0.000050 loss : 0.493514 +[07:41:51.183] iteration 32800 [1343.24 sec]: learning rate : 0.000050 loss : 0.679569 +[07:43:17.521] iteration 32900 [1429.57 sec]: learning rate : 0.000050 loss : 0.398449 +[07:44:43.901] iteration 33000 [1515.95 sec]: learning rate : 0.000050 loss : 0.562773 +[07:46:10.277] iteration 33100 [1602.33 sec]: learning rate : 0.000050 loss : 0.355039 +[07:47:36.608] iteration 33200 [1688.66 sec]: learning rate : 0.000050 loss : 0.680743 +[07:49:02.988] iteration 33300 [1775.04 sec]: learning rate : 0.000050 loss : 0.520148 +[07:49:27.123] Epoch 15 Evaluation: +[07:50:16.810] average MSE: 0.04141544923186302 average PSNR: 29.471050856970276 average SSIM: 0.7290272455339905 +[07:51:19.284] iteration 33400 [62.41 sec]: learning rate : 0.000050 loss : 0.227212 +[07:52:45.595] iteration 33500 [148.72 sec]: learning rate : 0.000050 loss : 0.568949 +[07:54:11.931] iteration 33600 [235.06 sec]: learning rate : 0.000050 loss : 0.343168 +[07:55:38.236] iteration 33700 [321.36 sec]: learning rate : 0.000050 loss : 0.349878 +[07:57:04.588] iteration 33800 [407.72 sec]: learning rate : 0.000050 loss : 0.470890 +[07:58:30.903] iteration 33900 [494.03 sec]: learning rate : 0.000050 loss : 0.405791 +[07:59:57.231] iteration 34000 [580.36 sec]: learning rate : 0.000050 loss : 0.768749 +[08:01:23.572] iteration 34100 [666.70 sec]: learning rate : 0.000050 loss : 0.517028 +[08:02:49.864] iteration 34200 [752.99 sec]: learning rate : 0.000050 loss : 0.452437 +[08:04:16.174] iteration 34300 [839.31 sec]: learning rate : 0.000050 loss : 0.457513 +[08:05:42.538] iteration 34400 [925.67 sec]: learning rate : 0.000050 loss : 0.541207 +[08:07:08.826] iteration 34500 [1011.95 sec]: learning rate : 0.000050 loss : 0.418657 +[08:08:35.119] iteration 34600 [1098.25 sec]: learning rate : 0.000050 loss : 0.655931 +[08:10:01.383] iteration 34700 [1184.51 sec]: learning rate : 0.000050 loss : 0.521759 +[08:11:27.729] iteration 34800 [1270.86 sec]: learning rate : 0.000050 loss : 0.966353 +[08:12:54.082] iteration 34900 [1357.21 sec]: learning rate : 0.000050 loss : 0.312306 +[08:14:20.341] iteration 35000 [1443.47 sec]: learning rate : 0.000050 loss : 0.543617 +[08:15:46.638] iteration 35100 [1529.77 sec]: learning rate : 0.000050 loss : 0.723271 +[08:17:12.866] iteration 35200 [1615.99 sec]: learning rate : 0.000050 loss : 0.467648 +[08:18:39.164] iteration 35300 [1702.29 sec]: learning rate : 0.000050 loss : 0.469696 +[08:20:05.446] iteration 35400 [1788.58 sec]: learning rate : 0.000050 loss : 0.343912 +[08:20:14.898] Epoch 16 Evaluation: +[08:21:04.555] average MSE: 0.04076014831662178 average PSNR: 29.54148054134856 average SSIM: 0.7301459840382901 +[08:22:21.564] iteration 35500 [76.95 sec]: learning rate : 0.000050 loss : 0.530869 +[08:23:47.960] iteration 35600 [163.34 sec]: learning rate : 0.000050 loss : 0.655292 +[08:25:14.338] iteration 35700 [249.72 sec]: learning rate : 0.000050 loss : 0.380121 +[08:26:40.670] iteration 35800 [336.05 sec]: learning rate : 0.000050 loss : 0.227082 +[08:28:07.061] iteration 35900 [422.44 sec]: learning rate : 0.000050 loss : 0.558068 +[08:29:33.457] 
iteration 36000 [508.84 sec]: learning rate : 0.000050 loss : 0.606173 +[08:30:59.786] iteration 36100 [595.17 sec]: learning rate : 0.000050 loss : 0.470676 +[08:32:26.178] iteration 36200 [681.56 sec]: learning rate : 0.000050 loss : 0.588330 +[08:33:52.591] iteration 36300 [767.99 sec]: learning rate : 0.000050 loss : 0.411750 +[08:35:18.915] iteration 36400 [854.30 sec]: learning rate : 0.000050 loss : 0.426800 +[08:36:45.287] iteration 36500 [940.67 sec]: learning rate : 0.000050 loss : 0.427112 +[08:38:11.633] iteration 36600 [1027.02 sec]: learning rate : 0.000050 loss : 0.350164 +[08:39:38.041] iteration 36700 [1113.42 sec]: learning rate : 0.000050 loss : 0.400070 +[08:41:04.474] iteration 36800 [1199.86 sec]: learning rate : 0.000050 loss : 0.495335 +[08:42:30.909] iteration 36900 [1286.29 sec]: learning rate : 0.000050 loss : 0.542840 +[08:43:57.295] iteration 37000 [1372.68 sec]: learning rate : 0.000050 loss : 0.490900 +[08:45:23.762] iteration 37100 [1459.14 sec]: learning rate : 0.000050 loss : 0.342629 +[08:46:50.149] iteration 37200 [1545.53 sec]: learning rate : 0.000050 loss : 0.636382 +[08:48:16.605] iteration 37300 [1631.99 sec]: learning rate : 0.000050 loss : 0.815241 +[08:49:43.103] iteration 37400 [1718.49 sec]: learning rate : 0.000050 loss : 0.436120 +[08:51:04.320] Epoch 17 Evaluation: +[08:51:54.323] average MSE: 0.04110564664006233 average PSNR: 29.509501717786893 average SSIM: 0.7301773727083513 +[08:51:59.760] iteration 37500 [5.38 sec]: learning rate : 0.000050 loss : 0.412250 +[08:53:26.240] iteration 37600 [91.85 sec]: learning rate : 0.000050 loss : 0.513170 +[08:54:52.706] iteration 37700 [178.32 sec]: learning rate : 0.000050 loss : 0.261322 +[08:56:19.129] iteration 37800 [264.75 sec]: learning rate : 0.000050 loss : 0.310563 +[08:57:45.615] iteration 37900 [351.23 sec]: learning rate : 0.000050 loss : 0.399285 +[08:59:12.121] iteration 38000 [437.73 sec]: learning rate : 0.000050 loss : 0.355199 +[09:00:38.559] iteration 38100 [524.17 sec]: learning rate : 0.000050 loss : 0.449760 +[09:02:05.063] iteration 38200 [610.68 sec]: learning rate : 0.000050 loss : 0.570206 +[09:03:31.554] iteration 38300 [697.17 sec]: learning rate : 0.000050 loss : 0.335430 +[09:04:57.982] iteration 38400 [783.60 sec]: learning rate : 0.000050 loss : 0.256574 +[09:06:24.474] iteration 38500 [870.09 sec]: learning rate : 0.000050 loss : 0.740203 +[09:07:50.977] iteration 38600 [956.59 sec]: learning rate : 0.000050 loss : 0.430809 +[09:09:17.409] iteration 38700 [1043.02 sec]: learning rate : 0.000050 loss : 0.368418 +[09:10:43.926] iteration 38800 [1129.54 sec]: learning rate : 0.000050 loss : 0.391729 +[09:12:10.367] iteration 38900 [1215.98 sec]: learning rate : 0.000050 loss : 0.335248 +[09:13:36.862] iteration 39000 [1302.48 sec]: learning rate : 0.000050 loss : 0.332407 +[09:15:03.377] iteration 39100 [1388.99 sec]: learning rate : 0.000050 loss : 0.399490 +[09:16:29.833] iteration 39200 [1475.45 sec]: learning rate : 0.000050 loss : 0.635365 +[09:17:56.321] iteration 39300 [1561.94 sec]: learning rate : 0.000050 loss : 0.595150 +[09:19:22.812] iteration 39400 [1648.43 sec]: learning rate : 0.000050 loss : 0.347131 +[09:20:49.242] iteration 39500 [1734.86 sec]: learning rate : 0.000050 loss : 0.679169 +[09:21:55.824] Epoch 18 Evaluation: +[09:22:46.897] average MSE: 0.04108226299285889 average PSNR: 29.500863698802945 average SSIM: 0.7297674378318537 +[09:23:06.987] iteration 39600 [20.03 sec]: learning rate : 0.000050 loss : 0.431614 +[09:24:33.330] iteration 39700 
[106.37 sec]: learning rate : 0.000050 loss : 0.442770 +[09:25:59.604] iteration 39800 [192.64 sec]: learning rate : 0.000050 loss : 0.437596 +[09:27:25.927] iteration 39900 [278.97 sec]: learning rate : 0.000050 loss : 0.376981 +[09:28:52.286] iteration 40000 [365.33 sec]: learning rate : 0.000013 loss : 0.431217 +[09:28:52.450] save model to model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/iter_40000.pth +[09:30:18.745] iteration 40100 [451.79 sec]: learning rate : 0.000025 loss : 0.473692 +[09:31:45.080] iteration 40200 [538.12 sec]: learning rate : 0.000025 loss : 0.641180 +[09:33:11.427] iteration 40300 [624.47 sec]: learning rate : 0.000025 loss : 0.510182 +[09:34:37.706] iteration 40400 [710.75 sec]: learning rate : 0.000025 loss : 0.608605 +[09:36:04.049] iteration 40500 [797.09 sec]: learning rate : 0.000025 loss : 0.630905 +[09:37:30.421] iteration 40600 [883.46 sec]: learning rate : 0.000025 loss : 0.544204 +[09:38:56.710] iteration 40700 [969.75 sec]: learning rate : 0.000025 loss : 0.390633 +[09:40:23.049] iteration 40800 [1056.09 sec]: learning rate : 0.000025 loss : 0.521406 +[09:41:49.353] iteration 40900 [1142.39 sec]: learning rate : 0.000025 loss : 0.571820 +[09:43:15.645] iteration 41000 [1228.69 sec]: learning rate : 0.000025 loss : 0.550191 +[09:44:42.002] iteration 41100 [1315.04 sec]: learning rate : 0.000025 loss : 0.407522 +[09:46:08.342] iteration 41200 [1401.38 sec]: learning rate : 0.000025 loss : 0.491996 +[09:47:34.632] iteration 41300 [1487.67 sec]: learning rate : 0.000025 loss : 0.405697 +[09:49:00.977] iteration 41400 [1574.02 sec]: learning rate : 0.000025 loss : 0.497828 +[09:50:27.276] iteration 41500 [1660.32 sec]: learning rate : 0.000025 loss : 0.383086 +[09:51:53.614] iteration 41600 [1746.65 sec]: learning rate : 0.000025 loss : 0.228946 +[09:52:45.359] Epoch 19 Evaluation: +[09:53:36.351] average MSE: 0.04011804237961769 average PSNR: 29.621978647169204 average SSIM: 0.7303459233600961 +[09:54:11.090] iteration 41700 [34.68 sec]: learning rate : 0.000025 loss : 0.469195 +[09:55:37.350] iteration 41800 [120.94 sec]: learning rate : 0.000025 loss : 0.574088 +[09:57:03.694] iteration 41900 [207.28 sec]: learning rate : 0.000025 loss : 0.390816 +[09:58:30.037] iteration 42000 [293.62 sec]: learning rate : 0.000025 loss : 0.495360 +[09:59:56.319] iteration 42100 [379.90 sec]: learning rate : 0.000025 loss : 0.400117 +[10:01:22.634] iteration 42200 [466.22 sec]: learning rate : 0.000025 loss : 0.422953 +[10:02:48.943] iteration 42300 [552.53 sec]: learning rate : 0.000025 loss : 0.727552 +[10:04:15.296] iteration 42400 [638.88 sec]: learning rate : 0.000025 loss : 0.703930 +[10:05:41.630] iteration 42500 [725.22 sec]: learning rate : 0.000025 loss : 0.412949 +[10:07:07.957] iteration 42600 [811.54 sec]: learning rate : 0.000025 loss : 0.699826 +[10:08:34.404] iteration 42700 [897.99 sec]: learning rate : 0.000025 loss : 0.512419 +[10:10:00.730] iteration 42800 [984.32 sec]: learning rate : 0.000025 loss : 0.871046 +[10:11:27.114] iteration 42900 [1070.70 sec]: learning rate : 0.000025 loss : 0.300010 +[10:12:53.506] iteration 43000 [1157.09 sec]: learning rate : 0.000025 loss : 0.353374 +[10:14:19.826] iteration 43100 [1243.41 sec]: learning rate : 0.000025 loss : 0.253845 +[10:15:46.174] iteration 43200 [1329.76 sec]: learning rate : 0.000025 loss : 0.476577 +[10:17:12.497] iteration 43300 [1416.08 sec]: learning rate : 0.000025 loss : 0.475726 +[10:18:38.877] iteration 43400 [1502.46 sec]: learning rate : 0.000025 loss : 0.299287 +[10:20:05.240] 
iteration 43500 [1588.83 sec]: learning rate : 0.000025 loss : 0.435919 +[10:21:31.563] iteration 43600 [1675.15 sec]: learning rate : 0.000025 loss : 0.646042 +[10:22:57.932] iteration 43700 [1761.52 sec]: learning rate : 0.000025 loss : 0.779587 +[10:23:35.021] Epoch 20 Evaluation: +[10:24:24.503] average MSE: 0.040429577231407166 average PSNR: 29.597547217888224 average SSIM: 0.7300813788551914 +[10:25:14.038] iteration 43800 [49.47 sec]: learning rate : 0.000025 loss : 0.455693 +[10:26:40.328] iteration 43900 [135.76 sec]: learning rate : 0.000025 loss : 0.471236 +[10:28:06.665] iteration 44000 [222.10 sec]: learning rate : 0.000025 loss : 0.312959 +[10:29:32.994] iteration 44100 [308.43 sec]: learning rate : 0.000025 loss : 0.350395 +[10:30:59.343] iteration 44200 [394.78 sec]: learning rate : 0.000025 loss : 0.639531 +[10:32:25.766] iteration 44300 [481.20 sec]: learning rate : 0.000025 loss : 0.350974 +[10:33:52.069] iteration 44400 [567.50 sec]: learning rate : 0.000025 loss : 0.319982 +[10:35:18.438] iteration 44500 [653.89 sec]: learning rate : 0.000025 loss : 0.572275 +[10:36:44.824] iteration 44600 [740.26 sec]: learning rate : 0.000025 loss : 0.500600 +[10:38:11.126] iteration 44700 [826.56 sec]: learning rate : 0.000025 loss : 0.274631 +[10:39:37.495] iteration 44800 [912.93 sec]: learning rate : 0.000025 loss : 0.782294 +[10:41:03.866] iteration 44900 [999.30 sec]: learning rate : 0.000025 loss : 0.651744 +[10:42:30.181] iteration 45000 [1085.61 sec]: learning rate : 0.000025 loss : 0.442358 +[10:43:56.540] iteration 45100 [1171.97 sec]: learning rate : 0.000025 loss : 0.414184 +[10:45:22.856] iteration 45200 [1258.29 sec]: learning rate : 0.000025 loss : 0.479798 +[10:46:49.207] iteration 45300 [1344.64 sec]: learning rate : 0.000025 loss : 0.443196 +[10:48:15.578] iteration 45400 [1431.01 sec]: learning rate : 0.000025 loss : 0.377541 +[10:49:41.897] iteration 45500 [1517.33 sec]: learning rate : 0.000025 loss : 0.404760 +[10:51:08.245] iteration 45600 [1603.68 sec]: learning rate : 0.000025 loss : 0.529053 +[10:52:34.621] iteration 45700 [1690.06 sec]: learning rate : 0.000025 loss : 0.723548 +[10:54:00.918] iteration 45800 [1776.35 sec]: learning rate : 0.000025 loss : 0.565349 +[10:54:23.341] Epoch 21 Evaluation: +[10:55:12.505] average MSE: 0.040435925126075745 average PSNR: 29.59750601263131 average SSIM: 0.7319236445884675 +[10:56:16.725] iteration 45900 [64.16 sec]: learning rate : 0.000025 loss : 0.566175 +[10:57:42.998] iteration 46000 [150.43 sec]: learning rate : 0.000025 loss : 0.351117 +[10:59:09.305] iteration 46100 [236.74 sec]: learning rate : 0.000025 loss : 0.643690 +[11:00:35.666] iteration 46200 [323.10 sec]: learning rate : 0.000025 loss : 0.473406 +[11:02:01.984] iteration 46300 [409.42 sec]: learning rate : 0.000025 loss : 0.418713 +[11:03:28.351] iteration 46400 [495.80 sec]: learning rate : 0.000025 loss : 0.798176 +[11:04:54.715] iteration 46500 [582.15 sec]: learning rate : 0.000025 loss : 0.287205 +[11:06:21.019] iteration 46600 [668.45 sec]: learning rate : 0.000025 loss : 0.522081 +[11:07:47.418] iteration 46700 [754.85 sec]: learning rate : 0.000025 loss : 0.441123 +[11:09:13.796] iteration 46800 [841.23 sec]: learning rate : 0.000025 loss : 0.580453 +[11:10:40.126] iteration 46900 [927.56 sec]: learning rate : 0.000025 loss : 0.471597 +[11:12:06.493] iteration 47000 [1013.92 sec]: learning rate : 0.000025 loss : 0.376851 +[11:13:32.813] iteration 47100 [1100.25 sec]: learning rate : 0.000025 loss : 0.492584 +[11:14:59.206] iteration 47200 
[1186.64 sec]: learning rate : 0.000025 loss : 0.649696 +[11:16:25.626] iteration 47300 [1273.06 sec]: learning rate : 0.000025 loss : 0.670292 +[11:17:51.961] iteration 47400 [1359.39 sec]: learning rate : 0.000025 loss : 0.453381 +[11:19:18.368] iteration 47500 [1445.80 sec]: learning rate : 0.000025 loss : 0.587914 +[11:20:44.723] iteration 47600 [1532.16 sec]: learning rate : 0.000025 loss : 0.394305 +[11:22:11.037] iteration 47700 [1618.47 sec]: learning rate : 0.000025 loss : 0.475338 +[11:23:37.392] iteration 47800 [1704.82 sec]: learning rate : 0.000025 loss : 0.423941 +[11:25:03.748] iteration 47900 [1791.18 sec]: learning rate : 0.000025 loss : 0.534656 +[11:25:11.501] Epoch 22 Evaluation: +[11:26:02.382] average MSE: 0.04035136476159096 average PSNR: 29.605522979665967 average SSIM: 0.7315147814286499 +[11:27:21.117] iteration 48000 [78.67 sec]: learning rate : 0.000025 loss : 0.595057 +[11:28:47.558] iteration 48100 [165.11 sec]: learning rate : 0.000025 loss : 0.565298 +[11:30:13.887] iteration 48200 [251.44 sec]: learning rate : 0.000025 loss : 0.579396 +[11:31:40.274] iteration 48300 [337.83 sec]: learning rate : 0.000025 loss : 0.279025 +[11:33:06.670] iteration 48400 [424.23 sec]: learning rate : 0.000025 loss : 0.435553 +[11:34:33.018] iteration 48500 [510.57 sec]: learning rate : 0.000025 loss : 0.531974 +[11:35:59.385] iteration 48600 [596.94 sec]: learning rate : 0.000025 loss : 0.379668 +[11:37:25.756] iteration 48700 [683.31 sec]: learning rate : 0.000025 loss : 0.459625 +[11:38:52.058] iteration 48800 [769.61 sec]: learning rate : 0.000025 loss : 0.292094 +[11:40:18.423] iteration 48900 [855.98 sec]: learning rate : 0.000025 loss : 0.532548 +[11:41:44.828] iteration 49000 [942.38 sec]: learning rate : 0.000025 loss : 0.522658 +[11:43:11.168] iteration 49100 [1028.72 sec]: learning rate : 0.000025 loss : 0.342265 +[11:44:37.560] iteration 49200 [1115.12 sec]: learning rate : 0.000025 loss : 0.691609 +[11:46:03.980] iteration 49300 [1201.54 sec]: learning rate : 0.000025 loss : 0.455292 +[11:47:30.328] iteration 49400 [1287.88 sec]: learning rate : 0.000025 loss : 0.523370 +[11:48:56.737] iteration 49500 [1374.29 sec]: learning rate : 0.000025 loss : 0.699895 +[11:50:23.140] iteration 49600 [1460.70 sec]: learning rate : 0.000025 loss : 0.298364 +[11:51:49.488] iteration 49700 [1547.04 sec]: learning rate : 0.000025 loss : 0.719678 +[11:53:15.896] iteration 49800 [1633.45 sec]: learning rate : 0.000025 loss : 0.444734 +[11:54:42.242] iteration 49900 [1719.80 sec]: learning rate : 0.000025 loss : 0.351121 +[11:56:01.704] Epoch 23 Evaluation: +[11:56:51.457] average MSE: 0.04013755917549133 average PSNR: 29.628876486878564 average SSIM: 0.7315756079516257 +[11:56:58.593] iteration 50000 [7.07 sec]: learning rate : 0.000025 loss : 0.467555 +[11:58:25.026] iteration 50100 [93.51 sec]: learning rate : 0.000025 loss : 0.719309 +[11:59:51.379] iteration 50200 [179.86 sec]: learning rate : 0.000025 loss : 0.536575 +[12:01:17.703] iteration 50300 [266.18 sec]: learning rate : 0.000025 loss : 0.603630 +[12:02:44.072] iteration 50400 [352.55 sec]: learning rate : 0.000025 loss : 0.497382 +[12:04:10.449] iteration 50500 [438.99 sec]: learning rate : 0.000025 loss : 0.458399 +[12:05:36.774] iteration 50600 [525.26 sec]: learning rate : 0.000025 loss : 0.481237 +[12:07:03.192] iteration 50700 [611.67 sec]: learning rate : 0.000025 loss : 0.758374 +[12:08:29.541] iteration 50800 [698.02 sec]: learning rate : 0.000025 loss : 0.327972 +[12:09:55.979] iteration 50900 [784.46 sec]: 
learning rate : 0.000025 loss : 0.346489 +[12:11:22.393] iteration 51000 [870.87 sec]: learning rate : 0.000025 loss : 0.546119 +[12:12:48.768] iteration 51100 [957.25 sec]: learning rate : 0.000025 loss : 0.503415 +[12:14:15.182] iteration 51200 [1043.66 sec]: learning rate : 0.000025 loss : 0.386414 +[12:15:41.596] iteration 51300 [1130.08 sec]: learning rate : 0.000025 loss : 0.289618 +[12:17:07.972] iteration 51400 [1216.45 sec]: learning rate : 0.000025 loss : 0.482045 +[12:18:34.392] iteration 51500 [1302.87 sec]: learning rate : 0.000025 loss : 0.259252 +[12:20:00.774] iteration 51600 [1389.25 sec]: learning rate : 0.000025 loss : 0.490257 +[12:21:27.226] iteration 51700 [1475.71 sec]: learning rate : 0.000025 loss : 0.198756 +[12:22:53.648] iteration 51800 [1562.13 sec]: learning rate : 0.000025 loss : 0.378444 +[12:24:20.031] iteration 51900 [1648.51 sec]: learning rate : 0.000025 loss : 0.726749 +[12:25:46.485] iteration 52000 [1734.96 sec]: learning rate : 0.000025 loss : 0.349273 +[12:26:51.257] Epoch 24 Evaluation: +[12:27:41.131] average MSE: 0.04001776874065399 average PSNR: 29.65282902913081 average SSIM: 0.7322920112295898 +[12:28:02.953] iteration 52100 [21.76 sec]: learning rate : 0.000025 loss : 0.640960 +[12:29:29.438] iteration 52200 [108.24 sec]: learning rate : 0.000025 loss : 0.355006 +[12:30:55.878] iteration 52300 [194.69 sec]: learning rate : 0.000025 loss : 0.597864 +[12:32:22.259] iteration 52400 [281.07 sec]: learning rate : 0.000025 loss : 0.392446 +[12:33:48.721] iteration 52500 [367.53 sec]: learning rate : 0.000025 loss : 0.456238 +[12:35:15.170] iteration 52600 [453.98 sec]: learning rate : 0.000025 loss : 0.639171 +[12:36:41.571] iteration 52700 [540.38 sec]: learning rate : 0.000025 loss : 0.376995 +[12:38:08.042] iteration 52800 [626.85 sec]: learning rate : 0.000025 loss : 0.386344 +[12:39:34.453] iteration 52900 [713.26 sec]: learning rate : 0.000025 loss : 0.624368 +[12:41:00.940] iteration 53000 [799.75 sec]: learning rate : 0.000025 loss : 0.641141 +[12:42:27.403] iteration 53100 [886.21 sec]: learning rate : 0.000025 loss : 0.470654 +[12:43:53.820] iteration 53200 [972.63 sec]: learning rate : 0.000025 loss : 0.377502 +[12:45:20.303] iteration 53300 [1059.11 sec]: learning rate : 0.000025 loss : 0.357791 +[12:46:46.786] iteration 53400 [1145.59 sec]: learning rate : 0.000025 loss : 0.412684 +[12:48:13.214] iteration 53500 [1232.02 sec]: learning rate : 0.000025 loss : 0.713839 +[12:49:39.716] iteration 53600 [1318.52 sec]: learning rate : 0.000025 loss : 0.670644 +[12:51:06.198] iteration 53700 [1405.00 sec]: learning rate : 0.000025 loss : 0.313578 +[12:52:32.617] iteration 53800 [1491.42 sec]: learning rate : 0.000025 loss : 0.503427 +[12:53:59.066] iteration 53900 [1577.87 sec]: learning rate : 0.000025 loss : 0.304087 +[12:55:25.505] iteration 54000 [1664.31 sec]: learning rate : 0.000025 loss : 0.757058 +[12:56:51.947] iteration 54100 [1750.75 sec]: learning rate : 0.000025 loss : 0.295556 +[12:57:42.074] Epoch 25 Evaluation: +[12:58:31.615] average MSE: 0.040289636701345444 average PSNR: 29.63123277318829 average SSIM: 0.732790255177136 +[12:59:08.265] iteration 54200 [36.59 sec]: learning rate : 0.000025 loss : 0.482293 +[13:00:34.706] iteration 54300 [123.03 sec]: learning rate : 0.000025 loss : 0.501643 +[13:02:01.244] iteration 54400 [209.57 sec]: learning rate : 0.000025 loss : 0.442793 +[13:03:27.758] iteration 54500 [296.08 sec]: learning rate : 0.000025 loss : 0.701870 +[13:04:54.216] iteration 54600 [382.54 sec]: learning rate : 
0.000025 loss : 0.834805 +[13:06:20.747] iteration 54700 [469.07 sec]: learning rate : 0.000025 loss : 0.484366 +[13:07:47.292] iteration 54800 [555.61 sec]: learning rate : 0.000025 loss : 0.599442 +[13:09:13.774] iteration 54900 [642.10 sec]: learning rate : 0.000025 loss : 0.641012 +[13:10:40.307] iteration 55000 [728.63 sec]: learning rate : 0.000025 loss : 0.346707 +[13:12:06.851] iteration 55100 [815.17 sec]: learning rate : 0.000025 loss : 0.291528 +[13:13:33.329] iteration 55200 [901.65 sec]: learning rate : 0.000025 loss : 0.289852 +[13:14:59.843] iteration 55300 [988.16 sec]: learning rate : 0.000025 loss : 0.390987 +[13:16:26.314] iteration 55400 [1074.64 sec]: learning rate : 0.000025 loss : 0.518253 +[13:17:52.816] iteration 55500 [1161.14 sec]: learning rate : 0.000025 loss : 0.479555 +[13:19:19.332] iteration 55600 [1247.65 sec]: learning rate : 0.000025 loss : 0.837368 +[13:20:45.786] iteration 55700 [1334.11 sec]: learning rate : 0.000025 loss : 0.398033 +[13:22:12.297] iteration 55800 [1420.62 sec]: learning rate : 0.000025 loss : 0.481792 +[13:23:38.807] iteration 55900 [1507.13 sec]: learning rate : 0.000025 loss : 0.356895 +[13:25:05.248] iteration 56000 [1593.57 sec]: learning rate : 0.000025 loss : 0.744133 +[13:26:31.752] iteration 56100 [1680.07 sec]: learning rate : 0.000025 loss : 0.582192 +[13:27:58.286] iteration 56200 [1766.61 sec]: learning rate : 0.000025 loss : 0.395901 +[13:28:33.725] Epoch 26 Evaluation: +[13:29:23.092] average MSE: 0.03991933539509773 average PSNR: 29.674105000685653 average SSIM: 0.7331039330753791 +[13:30:14.314] iteration 56300 [51.16 sec]: learning rate : 0.000025 loss : 0.412889 +[13:31:40.811] iteration 56400 [137.66 sec]: learning rate : 0.000025 loss : 0.582817 +[13:33:07.279] iteration 56500 [224.12 sec]: learning rate : 0.000025 loss : 0.207373 +[13:34:33.718] iteration 56600 [310.56 sec]: learning rate : 0.000025 loss : 0.417043 +[13:36:00.216] iteration 56700 [397.06 sec]: learning rate : 0.000025 loss : 0.518182 +[13:37:26.746] iteration 56800 [483.59 sec]: learning rate : 0.000025 loss : 0.857378 +[13:38:53.159] iteration 56900 [570.00 sec]: learning rate : 0.000025 loss : 0.686080 +[13:40:19.633] iteration 57000 [656.48 sec]: learning rate : 0.000025 loss : 0.359141 +[13:41:46.062] iteration 57100 [742.91 sec]: learning rate : 0.000025 loss : 0.513015 +[13:43:12.496] iteration 57200 [829.34 sec]: learning rate : 0.000025 loss : 0.344110 +[13:44:38.946] iteration 57300 [915.79 sec]: learning rate : 0.000025 loss : 0.480840 +[13:46:05.335] iteration 57400 [1002.18 sec]: learning rate : 0.000025 loss : 0.410470 +[13:47:31.770] iteration 57500 [1088.61 sec]: learning rate : 0.000025 loss : 0.821050 +[13:48:58.181] iteration 57600 [1175.03 sec]: learning rate : 0.000025 loss : 0.416297 +[13:50:24.560] iteration 57700 [1261.40 sec]: learning rate : 0.000025 loss : 0.402142 +[13:51:51.009] iteration 57800 [1347.86 sec]: learning rate : 0.000025 loss : 0.476532 +[13:53:17.433] iteration 57900 [1434.28 sec]: learning rate : 0.000025 loss : 0.516403 +[13:54:43.816] iteration 58000 [1520.66 sec]: learning rate : 0.000025 loss : 0.371293 +[13:56:10.253] iteration 58100 [1607.10 sec]: learning rate : 0.000025 loss : 0.498779 +[13:57:36.709] iteration 58200 [1693.55 sec]: learning rate : 0.000025 loss : 0.386382 +[13:59:03.087] iteration 58300 [1779.93 sec]: learning rate : 0.000025 loss : 0.329116 +[13:59:23.797] Epoch 27 Evaluation: +[14:00:13.218] average MSE: 0.04028173163533211 average PSNR: 29.626323165070108 average SSIM: 
0.7323039619701267 +[14:01:19.180] iteration 58400 [65.90 sec]: learning rate : 0.000025 loss : 0.468364 +[14:02:45.627] iteration 58500 [152.41 sec]: learning rate : 0.000025 loss : 0.394283 +[14:04:12.001] iteration 58600 [238.72 sec]: learning rate : 0.000025 loss : 0.729013 +[14:05:38.443] iteration 58700 [325.16 sec]: learning rate : 0.000025 loss : 0.561679 +[14:07:04.839] iteration 58800 [411.56 sec]: learning rate : 0.000025 loss : 0.369042 +[14:08:31.277] iteration 58900 [498.00 sec]: learning rate : 0.000025 loss : 0.450183 +[14:09:57.747] iteration 59000 [584.47 sec]: learning rate : 0.000025 loss : 0.360168 +[14:11:24.133] iteration 59100 [670.85 sec]: learning rate : 0.000025 loss : 0.484132 +[14:12:50.573] iteration 59200 [757.29 sec]: learning rate : 0.000025 loss : 0.526005 +[14:14:16.987] iteration 59300 [843.71 sec]: learning rate : 0.000025 loss : 0.745856 +[14:15:43.394] iteration 59400 [930.11 sec]: learning rate : 0.000025 loss : 0.335050 +[14:17:09.880] iteration 59500 [1016.60 sec]: learning rate : 0.000025 loss : 0.224885 +[14:18:36.314] iteration 59600 [1103.03 sec]: learning rate : 0.000025 loss : 0.338338 +[14:20:02.730] iteration 59700 [1189.45 sec]: learning rate : 0.000025 loss : 0.506696 +[14:21:29.201] iteration 59800 [1275.92 sec]: learning rate : 0.000025 loss : 0.573035 +[14:22:55.624] iteration 59900 [1362.34 sec]: learning rate : 0.000025 loss : 0.462757 +[14:24:22.079] iteration 60000 [1448.80 sec]: learning rate : 0.000006 loss : 0.564784 +[14:24:22.237] save model to model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/iter_60000.pth +[14:25:48.685] iteration 60100 [1535.41 sec]: learning rate : 0.000013 loss : 0.305341 +[14:27:15.097] iteration 60200 [1621.82 sec]: learning rate : 0.000013 loss : 0.355793 +[14:28:41.559] iteration 60300 [1708.28 sec]: learning rate : 0.000013 loss : 0.603737 +[14:30:08.063] iteration 60400 [1794.78 sec]: learning rate : 0.000013 loss : 0.664747 +[14:30:14.102] Epoch 28 Evaluation: +[14:31:06.577] average MSE: 0.03977381810545921 average PSNR: 29.689213945071668 average SSIM: 0.7325541597382702 +[14:32:27.119] iteration 60500 [80.48 sec]: learning rate : 0.000013 loss : 0.767639 +[14:33:53.557] iteration 60600 [166.92 sec]: learning rate : 0.000013 loss : 0.375632 +[14:35:19.933] iteration 60700 [253.29 sec]: learning rate : 0.000013 loss : 0.544779 +[14:36:46.287] iteration 60800 [339.65 sec]: learning rate : 0.000013 loss : 0.180612 +[14:38:12.699] iteration 60900 [426.06 sec]: learning rate : 0.000013 loss : 0.676624 +[14:39:39.104] iteration 61000 [512.46 sec]: learning rate : 0.000013 loss : 0.371472 +[14:41:05.461] iteration 61100 [598.82 sec]: learning rate : 0.000013 loss : 0.546320 +[14:42:31.823] iteration 61200 [685.18 sec]: learning rate : 0.000013 loss : 0.297267 +[14:43:58.194] iteration 61300 [771.56 sec]: learning rate : 0.000013 loss : 0.606668 +[14:45:24.618] iteration 61400 [857.98 sec]: learning rate : 0.000013 loss : 0.579527 +[14:46:51.076] iteration 61500 [944.44 sec]: learning rate : 0.000013 loss : 0.483032 +[14:48:17.428] iteration 61600 [1030.79 sec]: learning rate : 0.000013 loss : 0.577978 +[14:49:43.837] iteration 61700 [1117.20 sec]: learning rate : 0.000013 loss : 0.241520 +[14:51:10.246] iteration 61800 [1203.60 sec]: learning rate : 0.000013 loss : 0.533256 +[14:52:36.603] iteration 61900 [1289.96 sec]: learning rate : 0.000013 loss : 0.449198 +[14:54:03.021] iteration 62000 [1376.38 sec]: learning rate : 0.000013 loss : 0.665353 +[14:55:29.398] iteration 62100 [1462.76 sec]: 
learning rate : 0.000013 loss : 0.746523 +[14:56:55.759] iteration 62200 [1549.12 sec]: learning rate : 0.000013 loss : 0.423584 +[14:58:22.205] iteration 62300 [1635.56 sec]: learning rate : 0.000013 loss : 0.347939 +[14:59:48.655] iteration 62400 [1722.01 sec]: learning rate : 0.000013 loss : 0.751172 +[15:01:06.368] Epoch 29 Evaluation: +[15:01:58.182] average MSE: 0.039816852658987045 average PSNR: 29.69398070653942 average SSIM: 0.7338917890457433 +[15:02:07.044] iteration 62500 [8.80 sec]: learning rate : 0.000013 loss : 0.374815 +[15:03:33.514] iteration 62600 [95.27 sec]: learning rate : 0.000013 loss : 0.342464 +[15:04:59.949] iteration 62700 [181.70 sec]: learning rate : 0.000013 loss : 0.459578 +[15:06:26.324] iteration 62800 [268.08 sec]: learning rate : 0.000013 loss : 0.406383 +[15:07:52.741] iteration 62900 [354.50 sec]: learning rate : 0.000013 loss : 0.466099 +[15:09:19.153] iteration 63000 [440.91 sec]: learning rate : 0.000013 loss : 0.477070 +[15:10:45.506] iteration 63100 [527.26 sec]: learning rate : 0.000013 loss : 0.504203 +[15:12:11.956] iteration 63200 [613.71 sec]: learning rate : 0.000013 loss : 0.451920 +[15:13:38.428] iteration 63300 [700.18 sec]: learning rate : 0.000013 loss : 0.365861 +[15:15:04.808] iteration 63400 [786.56 sec]: learning rate : 0.000013 loss : 0.646976 +[15:16:31.248] iteration 63500 [873.00 sec]: learning rate : 0.000013 loss : 0.504403 +[15:17:57.694] iteration 63600 [959.45 sec]: learning rate : 0.000013 loss : 0.299313 +[15:19:24.070] iteration 63700 [1045.83 sec]: learning rate : 0.000013 loss : 0.417621 +[15:20:50.491] iteration 63800 [1132.25 sec]: learning rate : 0.000013 loss : 0.409260 +[15:22:16.895] iteration 63900 [1218.65 sec]: learning rate : 0.000013 loss : 0.315567 +[15:23:43.360] iteration 64000 [1305.12 sec]: learning rate : 0.000013 loss : 0.568098 +[15:25:09.824] iteration 64100 [1391.58 sec]: learning rate : 0.000013 loss : 0.237587 +[15:26:36.216] iteration 64200 [1477.97 sec]: learning rate : 0.000013 loss : 0.315682 +[15:28:02.658] iteration 64300 [1564.41 sec]: learning rate : 0.000013 loss : 0.484486 +[15:29:29.120] iteration 64400 [1650.87 sec]: learning rate : 0.000013 loss : 0.470376 +[15:30:55.513] iteration 64500 [1737.27 sec]: learning rate : 0.000013 loss : 0.419645 +[15:31:58.609] Epoch 30 Evaluation: +[15:32:47.897] average MSE: 0.03984714671969414 average PSNR: 29.688381022395628 average SSIM: 0.7337293881516405 +[15:33:11.433] iteration 64600 [23.47 sec]: learning rate : 0.000013 loss : 0.642719 +[15:34:37.906] iteration 64700 [109.95 sec]: learning rate : 0.000013 loss : 0.394201 +[15:36:04.268] iteration 64800 [196.31 sec]: learning rate : 0.000013 loss : 0.388845 +[15:37:30.730] iteration 64900 [282.77 sec]: learning rate : 0.000013 loss : 0.316943 +[15:38:57.099] iteration 65000 [369.14 sec]: learning rate : 0.000013 loss : 0.425649 +[15:40:23.500] iteration 65100 [455.54 sec]: learning rate : 0.000013 loss : 0.472899 +[15:41:49.941] iteration 65200 [541.98 sec]: learning rate : 0.000013 loss : 0.386462 +[15:43:16.302] iteration 65300 [628.34 sec]: learning rate : 0.000013 loss : 0.344890 +[15:44:42.684] iteration 65400 [714.72 sec]: learning rate : 0.000013 loss : 0.525556 +[15:46:09.039] iteration 65500 [801.08 sec]: learning rate : 0.000013 loss : 0.635110 +[15:47:35.463] iteration 65600 [887.50 sec]: learning rate : 0.000013 loss : 0.565278 +[15:49:01.876] iteration 65700 [973.92 sec]: learning rate : 0.000013 loss : 0.408144 +[15:50:28.231] iteration 65800 [1060.27 sec]: learning rate : 
0.000013 loss : 0.545993 +[15:51:54.716] iteration 65900 [1146.76 sec]: learning rate : 0.000013 loss : 0.495034 +[15:53:21.116] iteration 66000 [1233.16 sec]: learning rate : 0.000013 loss : 0.384706 +[15:54:47.499] iteration 66100 [1319.54 sec]: learning rate : 0.000013 loss : 0.531193 +[15:56:13.926] iteration 66200 [1405.98 sec]: learning rate : 0.000013 loss : 0.350046 +[15:57:40.316] iteration 66300 [1492.36 sec]: learning rate : 0.000013 loss : 0.419068 +[15:59:06.686] iteration 66400 [1578.73 sec]: learning rate : 0.000013 loss : 0.379042 +[16:00:33.105] iteration 66500 [1665.15 sec]: learning rate : 0.000013 loss : 0.560455 +[16:01:59.476] iteration 66600 [1751.52 sec]: learning rate : 0.000013 loss : 0.473330 +[16:02:47.831] Epoch 31 Evaluation: +[16:03:36.941] average MSE: 0.03981137275695801 average PSNR: 29.685770046619368 average SSIM: 0.7330783625730258 +[16:04:15.315] iteration 66700 [38.31 sec]: learning rate : 0.000013 loss : 0.683329 +[16:05:41.745] iteration 66800 [124.74 sec]: learning rate : 0.000013 loss : 0.334276 +[16:07:08.237] iteration 66900 [211.23 sec]: learning rate : 0.000013 loss : 0.427328 +[16:08:34.703] iteration 67000 [297.70 sec]: learning rate : 0.000013 loss : 0.630211 +[16:10:01.150] iteration 67100 [384.15 sec]: learning rate : 0.000013 loss : 0.613792 +[16:11:27.648] iteration 67200 [470.64 sec]: learning rate : 0.000013 loss : 0.391902 +[16:12:54.182] iteration 67300 [557.18 sec]: learning rate : 0.000013 loss : 0.466301 +[16:14:20.684] iteration 67400 [643.68 sec]: learning rate : 0.000013 loss : 0.591091 +[16:15:47.201] iteration 67500 [730.20 sec]: learning rate : 0.000013 loss : 0.333354 +[16:17:13.670] iteration 67600 [816.67 sec]: learning rate : 0.000013 loss : 0.393909 +[16:18:40.181] iteration 67700 [903.18 sec]: learning rate : 0.000013 loss : 0.305483 +[16:20:06.685] iteration 67800 [989.68 sec]: learning rate : 0.000013 loss : 0.450345 +[16:21:33.220] iteration 67900 [1076.22 sec]: learning rate : 0.000013 loss : 0.288258 +[16:22:59.724] iteration 68000 [1162.72 sec]: learning rate : 0.000013 loss : 0.620942 +[16:24:26.193] iteration 68100 [1249.19 sec]: learning rate : 0.000013 loss : 0.523042 +[16:25:52.746] iteration 68200 [1335.74 sec]: learning rate : 0.000013 loss : 0.398665 +[16:27:19.255] iteration 68300 [1422.25 sec]: learning rate : 0.000013 loss : 0.518247 +[16:28:45.722] iteration 68400 [1508.72 sec]: learning rate : 0.000013 loss : 0.615253 +[16:30:12.225] iteration 68500 [1595.22 sec]: learning rate : 0.000013 loss : 0.452468 +[16:31:38.693] iteration 68600 [1681.69 sec]: learning rate : 0.000013 loss : 0.500999 +[16:33:05.188] iteration 68700 [1768.18 sec]: learning rate : 0.000013 loss : 0.423361 +[16:33:38.893] Epoch 32 Evaluation: +[16:34:28.597] average MSE: 0.039714232087135315 average PSNR: 29.70422198013783 average SSIM: 0.7336025108358771 +[16:35:21.729] iteration 68800 [53.09 sec]: learning rate : 0.000013 loss : 0.344594 +[16:36:48.192] iteration 68900 [139.53 sec]: learning rate : 0.000013 loss : 0.993826 +[16:38:14.663] iteration 69000 [226.00 sec]: learning rate : 0.000013 loss : 0.315432 +[16:39:41.101] iteration 69100 [312.44 sec]: learning rate : 0.000013 loss : 0.551249 +[16:41:07.618] iteration 69200 [398.96 sec]: learning rate : 0.000013 loss : 0.260525 +[16:42:34.115] iteration 69300 [485.46 sec]: learning rate : 0.000013 loss : 0.473193 +[16:44:00.608] iteration 69400 [571.95 sec]: learning rate : 0.000013 loss : 0.514966 +[16:45:27.123] iteration 69500 [658.46 sec]: learning rate : 0.000013 loss : 
0.296951 +[16:46:53.627] iteration 69600 [744.97 sec]: learning rate : 0.000013 loss : 0.560920 +[16:48:20.084] iteration 69700 [831.42 sec]: learning rate : 0.000013 loss : 0.377786 +[16:49:46.565] iteration 69800 [917.91 sec]: learning rate : 0.000013 loss : 0.546694 +[16:51:13.015] iteration 69900 [1004.35 sec]: learning rate : 0.000013 loss : 0.534010 +[16:52:39.523] iteration 70000 [1090.87 sec]: learning rate : 0.000013 loss : 0.363931 +[16:54:06.028] iteration 70100 [1177.37 sec]: learning rate : 0.000013 loss : 0.506557 +[16:55:32.467] iteration 70200 [1263.81 sec]: learning rate : 0.000013 loss : 0.363019 +[16:56:58.986] iteration 70300 [1350.33 sec]: learning rate : 0.000013 loss : 0.351988 +[16:58:25.514] iteration 70400 [1436.86 sec]: learning rate : 0.000013 loss : 0.491796 +[16:59:51.970] iteration 70500 [1523.31 sec]: learning rate : 0.000013 loss : 0.431326 +[17:01:18.488] iteration 70600 [1609.83 sec]: learning rate : 0.000013 loss : 0.537308 +[17:02:44.932] iteration 70700 [1696.27 sec]: learning rate : 0.000013 loss : 0.533605 +[17:04:11.445] iteration 70800 [1782.79 sec]: learning rate : 0.000013 loss : 0.573845 +[17:04:30.429] Epoch 33 Evaluation: +[17:05:21.735] average MSE: 0.03954209014773369 average PSNR: 29.71781296536392 average SSIM: 0.7337749188597659 +[17:06:29.528] iteration 70900 [67.73 sec]: learning rate : 0.000013 loss : 0.403161 +[17:07:56.015] iteration 71000 [154.22 sec]: learning rate : 0.000013 loss : 0.317139 +[17:09:22.610] iteration 71100 [240.81 sec]: learning rate : 0.000013 loss : 0.573630 +[17:10:49.158] iteration 71200 [327.36 sec]: learning rate : 0.000013 loss : 0.334954 +[17:12:15.640] iteration 71300 [413.84 sec]: learning rate : 0.000013 loss : 0.436905 +[17:13:42.162] iteration 71400 [500.36 sec]: learning rate : 0.000013 loss : 0.258164 +[17:15:08.681] iteration 71500 [586.88 sec]: learning rate : 0.000013 loss : 0.747940 +[17:16:35.169] iteration 71600 [673.37 sec]: learning rate : 0.000013 loss : 0.805186 +[17:18:01.716] iteration 71700 [759.92 sec]: learning rate : 0.000013 loss : 0.274757 +[17:19:28.210] iteration 71800 [846.41 sec]: learning rate : 0.000013 loss : 0.359038 +[17:20:54.768] iteration 71900 [932.97 sec]: learning rate : 0.000013 loss : 0.448475 +[17:22:21.322] iteration 72000 [1019.54 sec]: learning rate : 0.000013 loss : 0.486783 +[17:23:47.841] iteration 72100 [1106.04 sec]: learning rate : 0.000013 loss : 0.457273 +[17:25:14.373] iteration 72200 [1192.58 sec]: learning rate : 0.000013 loss : 0.598864 +[17:26:40.911] iteration 72300 [1279.11 sec]: learning rate : 0.000013 loss : 0.407814 +[17:28:07.413] iteration 72400 [1365.62 sec]: learning rate : 0.000013 loss : 0.371912 +[17:29:34.050] iteration 72500 [1452.25 sec]: learning rate : 0.000013 loss : 0.670269 +[17:31:00.594] iteration 72600 [1538.80 sec]: learning rate : 0.000013 loss : 0.505289 +[17:32:27.083] iteration 72700 [1625.28 sec]: learning rate : 0.000013 loss : 0.438103 +[17:33:53.623] iteration 72800 [1711.83 sec]: learning rate : 0.000013 loss : 0.493669 +[17:35:20.121] iteration 72900 [1798.32 sec]: learning rate : 0.000013 loss : 0.450648 +[17:35:24.429] Epoch 34 Evaluation: +[17:36:14.577] average MSE: 0.03953627496957779 average PSNR: 29.725164875974514 average SSIM: 0.734037489356742 +[17:37:36.931] iteration 73000 [82.29 sec]: learning rate : 0.000013 loss : 0.457524 +[17:39:03.475] iteration 73100 [168.84 sec]: learning rate : 0.000013 loss : 0.587911 +[17:40:29.901] iteration 73200 [255.26 sec]: learning rate : 0.000013 loss : 0.528713 
+[17:41:56.386] iteration 73300 [341.75 sec]: learning rate : 0.000013 loss : 0.331248 +[17:43:22.904] iteration 73400 [428.26 sec]: learning rate : 0.000013 loss : 0.419084 +[17:44:49.353] iteration 73500 [514.71 sec]: learning rate : 0.000013 loss : 0.402192 +[17:46:15.850] iteration 73600 [601.21 sec]: learning rate : 0.000013 loss : 0.489265 +[17:47:42.384] iteration 73700 [687.74 sec]: learning rate : 0.000013 loss : 0.481164 +[17:49:08.819] iteration 73800 [774.18 sec]: learning rate : 0.000013 loss : 0.388851 +[17:50:35.336] iteration 73900 [860.70 sec]: learning rate : 0.000013 loss : 0.617080 +[17:52:01.842] iteration 74000 [947.20 sec]: learning rate : 0.000013 loss : 0.415908 +[17:53:28.293] iteration 74100 [1033.65 sec]: learning rate : 0.000013 loss : 0.589983 +[17:54:54.835] iteration 74200 [1120.20 sec]: learning rate : 0.000013 loss : 0.328323 +[17:56:21.380] iteration 74300 [1206.74 sec]: learning rate : 0.000013 loss : 0.493654 +[17:57:47.852] iteration 74400 [1293.21 sec]: learning rate : 0.000013 loss : 0.317698 +[17:59:14.346] iteration 74500 [1379.71 sec]: learning rate : 0.000013 loss : 0.424790 +[18:00:40.868] iteration 74600 [1466.23 sec]: learning rate : 0.000013 loss : 0.630579 +[18:02:07.322] iteration 74700 [1552.68 sec]: learning rate : 0.000013 loss : 0.229838 +[18:03:33.799] iteration 74800 [1639.16 sec]: learning rate : 0.000013 loss : 0.549267 +[18:05:00.254] iteration 74900 [1725.61 sec]: learning rate : 0.000013 loss : 0.331882 +[18:06:16.362] Epoch 35 Evaluation: +[18:07:07.498] average MSE: 0.039555247873067856 average PSNR: 29.718267557757446 average SSIM: 0.7333775498254904 +[18:07:18.089] iteration 75000 [10.53 sec]: learning rate : 0.000013 loss : 0.451392 +[18:08:44.633] iteration 75100 [97.07 sec]: learning rate : 0.000013 loss : 0.330644 +[18:10:11.115] iteration 75200 [183.55 sec]: learning rate : 0.000013 loss : 0.412352 +[18:11:37.627] iteration 75300 [270.07 sec]: learning rate : 0.000013 loss : 0.257509 +[18:13:04.109] iteration 75400 [356.55 sec]: learning rate : 0.000013 loss : 0.522568 +[18:14:30.631] iteration 75500 [443.07 sec]: learning rate : 0.000013 loss : 0.496813 +[18:15:57.146] iteration 75600 [529.59 sec]: learning rate : 0.000013 loss : 0.651547 +[18:17:23.621] iteration 75700 [616.06 sec]: learning rate : 0.000013 loss : 0.437096 +[18:18:50.173] iteration 75800 [702.61 sec]: learning rate : 0.000013 loss : 0.358689 +[18:20:16.730] iteration 75900 [789.17 sec]: learning rate : 0.000013 loss : 0.448452 +[18:21:43.231] iteration 76000 [875.67 sec]: learning rate : 0.000013 loss : 0.394830 +[18:23:09.776] iteration 76100 [962.21 sec]: learning rate : 0.000013 loss : 0.587290 +[18:24:36.268] iteration 76200 [1048.71 sec]: learning rate : 0.000013 loss : 0.543954 +[18:26:02.814] iteration 76300 [1135.26 sec]: learning rate : 0.000013 loss : 0.480307 +[18:27:29.361] iteration 76400 [1221.80 sec]: learning rate : 0.000013 loss : 0.502946 +[18:28:55.857] iteration 76500 [1308.30 sec]: learning rate : 0.000013 loss : 0.672459 +[18:30:22.426] iteration 76600 [1394.86 sec]: learning rate : 0.000013 loss : 0.321023 +[18:31:49.000] iteration 76700 [1481.44 sec]: learning rate : 0.000013 loss : 0.333465 +[18:33:15.524] iteration 76800 [1567.96 sec]: learning rate : 0.000013 loss : 0.294391 +[18:34:42.094] iteration 76900 [1654.53 sec]: learning rate : 0.000013 loss : 0.434940 +[18:36:08.656] iteration 77000 [1741.10 sec]: learning rate : 0.000013 loss : 0.353849 +[18:37:10.054] Epoch 36 Evaluation: +[18:37:59.128] average MSE: 
0.03953191265463829 average PSNR: 29.724342725574047 average SSIM: 0.7346808441945455 +[18:38:24.448] iteration 77100 [25.26 sec]: learning rate : 0.000013 loss : 0.510330 +[18:39:51.074] iteration 77200 [111.88 sec]: learning rate : 0.000013 loss : 0.650248 +[18:41:17.696] iteration 77300 [198.51 sec]: learning rate : 0.000013 loss : 0.485292 +[18:42:44.230] iteration 77400 [285.04 sec]: learning rate : 0.000013 loss : 0.519908 +[18:44:10.794] iteration 77500 [371.60 sec]: learning rate : 0.000013 loss : 0.418368 +[18:45:37.389] iteration 77600 [458.20 sec]: learning rate : 0.000013 loss : 0.614133 +[18:47:03.927] iteration 77700 [544.74 sec]: learning rate : 0.000013 loss : 0.429768 +[18:48:30.535] iteration 77800 [631.35 sec]: learning rate : 0.000013 loss : 0.288885 +[18:49:57.076] iteration 77900 [717.89 sec]: learning rate : 0.000013 loss : 0.407142 +[18:51:23.635] iteration 78000 [804.44 sec]: learning rate : 0.000013 loss : 0.295989 +[18:52:50.264] iteration 78100 [891.07 sec]: learning rate : 0.000013 loss : 0.643033 +[18:54:16.800] iteration 78200 [977.61 sec]: learning rate : 0.000013 loss : 0.321205 +[18:55:43.456] iteration 78300 [1064.27 sec]: learning rate : 0.000013 loss : 0.393877 +[18:57:10.100] iteration 78400 [1150.91 sec]: learning rate : 0.000013 loss : 0.479289 +[18:58:36.658] iteration 78500 [1237.47 sec]: learning rate : 0.000013 loss : 0.348587 +[19:00:03.296] iteration 78600 [1324.11 sec]: learning rate : 0.000013 loss : 0.360308 +[19:01:29.910] iteration 78700 [1410.72 sec]: learning rate : 0.000013 loss : 0.597933 +[19:02:56.467] iteration 78800 [1497.28 sec]: learning rate : 0.000013 loss : 0.463994 +[19:04:23.037] iteration 78900 [1583.85 sec]: learning rate : 0.000013 loss : 0.357104 +[19:05:49.587] iteration 79000 [1670.40 sec]: learning rate : 0.000013 loss : 0.400522 +[19:07:16.180] iteration 79100 [1756.99 sec]: learning rate : 0.000013 loss : 0.529239 +[19:08:02.871] Epoch 37 Evaluation: +[19:08:52.320] average MSE: 0.03962287679314613 average PSNR: 29.715643343200583 average SSIM: 0.7339576283804022 +[19:09:32.502] iteration 79200 [40.12 sec]: learning rate : 0.000013 loss : 0.425555 +[19:10:59.071] iteration 79300 [126.69 sec]: learning rate : 0.000013 loss : 0.547022 +[19:12:25.664] iteration 79400 [213.28 sec]: learning rate : 0.000013 loss : 0.482697 +[19:13:52.226] iteration 79500 [299.84 sec]: learning rate : 0.000013 loss : 0.571378 +[19:15:18.735] iteration 79600 [386.35 sec]: learning rate : 0.000013 loss : 0.412372 +[19:16:45.315] iteration 79700 [472.93 sec]: learning rate : 0.000013 loss : 0.818248 +[19:18:11.932] iteration 79800 [559.55 sec]: learning rate : 0.000013 loss : 0.573188 +[19:19:38.470] iteration 79900 [646.09 sec]: learning rate : 0.000013 loss : 0.512814 +[19:21:05.063] iteration 80000 [732.68 sec]: learning rate : 0.000003 loss : 0.265936 +[19:21:05.222] save model to model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/iter_80000.pth +[19:22:31.747] iteration 80100 [819.37 sec]: learning rate : 0.000006 loss : 0.834989 +[19:23:58.382] iteration 80200 [906.00 sec]: learning rate : 0.000006 loss : 0.481169 +[19:25:25.003] iteration 80300 [992.62 sec]: learning rate : 0.000006 loss : 0.282075 +[19:26:51.545] iteration 80400 [1079.16 sec]: learning rate : 0.000006 loss : 0.316999 +[19:28:18.137] iteration 80500 [1165.75 sec]: learning rate : 0.000006 loss : 0.337826 +[19:29:44.725] iteration 80600 [1252.34 sec]: learning rate : 0.000006 loss : 0.750202 +[19:31:11.242] iteration 80700 [1338.86 sec]: learning rate : 0.000006 loss 
: 0.520212 +[19:32:37.814] iteration 80800 [1425.43 sec]: learning rate : 0.000006 loss : 0.602146 +[19:34:04.402] iteration 80900 [1512.02 sec]: learning rate : 0.000006 loss : 0.491330 +[19:35:30.899] iteration 81000 [1598.52 sec]: learning rate : 0.000006 loss : 0.512998 +[19:36:57.482] iteration 81100 [1685.10 sec]: learning rate : 0.000006 loss : 0.451538 +[19:38:24.078] iteration 81200 [1771.70 sec]: learning rate : 0.000006 loss : 0.333441 +[19:38:56.079] Epoch 38 Evaluation: +[19:39:45.688] average MSE: 0.03960740938782692 average PSNR: 29.718395650228114 average SSIM: 0.7343289509576589 +[19:40:40.369] iteration 81300 [54.62 sec]: learning rate : 0.000006 loss : 0.548704 +[19:42:06.947] iteration 81400 [141.20 sec]: learning rate : 0.000006 loss : 0.734063 +[19:43:33.425] iteration 81500 [227.67 sec]: learning rate : 0.000006 loss : 0.598093 +[19:44:59.947] iteration 81600 [314.20 sec]: learning rate : 0.000006 loss : 0.961480 +[19:46:26.492] iteration 81700 [400.74 sec]: learning rate : 0.000006 loss : 0.520043 +[19:47:52.954] iteration 81800 [487.20 sec]: learning rate : 0.000006 loss : 0.378037 +[19:49:19.473] iteration 81900 [573.72 sec]: learning rate : 0.000006 loss : 0.634308 +[19:50:46.000] iteration 82000 [660.25 sec]: learning rate : 0.000006 loss : 0.446492 +[19:52:12.460] iteration 82100 [746.71 sec]: learning rate : 0.000006 loss : 0.469385 +[19:53:38.979] iteration 82200 [833.23 sec]: learning rate : 0.000006 loss : 0.587709 +[19:55:05.540] iteration 82300 [919.79 sec]: learning rate : 0.000006 loss : 0.372541 +[19:56:32.042] iteration 82400 [1006.29 sec]: learning rate : 0.000006 loss : 0.527903 +[19:57:58.611] iteration 82500 [1092.86 sec]: learning rate : 0.000006 loss : 0.399382 +[19:59:25.186] iteration 82600 [1179.44 sec]: learning rate : 0.000006 loss : 0.525045 +[20:00:51.696] iteration 82700 [1265.95 sec]: learning rate : 0.000006 loss : 0.468304 +[20:02:18.288] iteration 82800 [1352.54 sec]: learning rate : 0.000006 loss : 0.502704 +[20:03:44.780] iteration 82900 [1439.03 sec]: learning rate : 0.000006 loss : 0.404664 +[20:05:11.360] iteration 83000 [1525.61 sec]: learning rate : 0.000006 loss : 0.440648 +[20:06:37.914] iteration 83100 [1612.16 sec]: learning rate : 0.000006 loss : 0.350009 +[20:08:04.429] iteration 83200 [1698.68 sec]: learning rate : 0.000006 loss : 0.284668 +[20:09:30.996] iteration 83300 [1785.24 sec]: learning rate : 0.000006 loss : 0.580974 +[20:09:48.274] Epoch 39 Evaluation: +[20:10:38.311] average MSE: 0.03937874734401703 average PSNR: 29.749575539214955 average SSIM: 0.7340718740525197 +[20:11:47.874] iteration 83400 [69.50 sec]: learning rate : 0.000006 loss : 0.533884 +[20:13:14.366] iteration 83500 [155.99 sec]: learning rate : 0.000006 loss : 0.565752 +[20:14:40.913] iteration 83600 [242.54 sec]: learning rate : 0.000006 loss : 0.562294 +[20:16:07.429] iteration 83700 [329.05 sec]: learning rate : 0.000006 loss : 0.399434 +[20:17:33.930] iteration 83800 [415.55 sec]: learning rate : 0.000006 loss : 0.469431 +[20:19:00.501] iteration 83900 [502.12 sec]: learning rate : 0.000006 loss : 0.255770 +[20:20:27.090] iteration 84000 [588.71 sec]: learning rate : 0.000006 loss : 0.686349 +[20:21:53.585] iteration 84100 [675.21 sec]: learning rate : 0.000006 loss : 0.859236 +[20:23:20.148] iteration 84200 [761.77 sec]: learning rate : 0.000006 loss : 0.368007 +[20:24:46.695] iteration 84300 [848.32 sec]: learning rate : 0.000006 loss : 0.446550 +[20:26:13.200] iteration 84400 [934.82 sec]: learning rate : 0.000006 loss : 0.695771 
+[20:27:39.791] iteration 84500 [1021.42 sec]: learning rate : 0.000006 loss : 0.544175 +[20:29:06.369] iteration 84600 [1107.99 sec]: learning rate : 0.000006 loss : 0.370248 +[20:30:32.869] iteration 84700 [1194.49 sec]: learning rate : 0.000006 loss : 0.518790 +[20:31:59.405] iteration 84800 [1281.03 sec]: learning rate : 0.000006 loss : 0.373519 +[20:33:25.961] iteration 84900 [1367.58 sec]: learning rate : 0.000006 loss : 0.466143 +[20:34:52.480] iteration 85000 [1454.10 sec]: learning rate : 0.000006 loss : 0.443961 +[20:36:19.033] iteration 85100 [1540.66 sec]: learning rate : 0.000006 loss : 0.245171 +[20:37:45.587] iteration 85200 [1627.21 sec]: learning rate : 0.000006 loss : 0.573308 +[20:39:12.081] iteration 85300 [1713.70 sec]: learning rate : 0.000006 loss : 0.434287 +[20:40:38.635] iteration 85400 [1800.26 sec]: learning rate : 0.000006 loss : 0.383230 +[20:40:41.218] Epoch 40 Evaluation: +[20:41:32.580] average MSE: 0.03948330134153366 average PSNR: 29.73136157136514 average SSIM: 0.7340377800921897 +[20:42:56.790] iteration 85500 [84.15 sec]: learning rate : 0.000006 loss : 0.306828 +[20:44:23.248] iteration 85600 [170.61 sec]: learning rate : 0.000006 loss : 0.549524 +[20:45:49.774] iteration 85700 [257.13 sec]: learning rate : 0.000006 loss : 0.643733 +[20:47:16.288] iteration 85800 [343.64 sec]: learning rate : 0.000006 loss : 0.392745 +[20:48:42.750] iteration 85900 [430.11 sec]: learning rate : 0.000006 loss : 0.548011 +[20:50:09.290] iteration 86000 [516.65 sec]: learning rate : 0.000006 loss : 0.603953 +[20:51:35.800] iteration 86100 [603.16 sec]: learning rate : 0.000006 loss : 0.367144 +[20:53:02.331] iteration 86200 [689.69 sec]: learning rate : 0.000006 loss : 0.472085 +[20:54:28.876] iteration 86300 [776.23 sec]: learning rate : 0.000006 loss : 0.447087 +[20:55:55.355] iteration 86400 [862.71 sec]: learning rate : 0.000006 loss : 0.270473 +[20:57:21.886] iteration 86500 [949.24 sec]: learning rate : 0.000006 loss : 0.646266 +[20:58:48.423] iteration 86600 [1035.78 sec]: learning rate : 0.000006 loss : 0.625527 +[21:00:14.931] iteration 86700 [1122.29 sec]: learning rate : 0.000006 loss : 0.253775 +[21:01:41.479] iteration 86800 [1208.84 sec]: learning rate : 0.000006 loss : 0.769632 +[21:03:08.008] iteration 86900 [1295.36 sec]: learning rate : 0.000006 loss : 0.687204 +[21:04:34.588] iteration 87000 [1381.94 sec]: learning rate : 0.000006 loss : 0.270284 +[21:06:01.143] iteration 87100 [1468.50 sec]: learning rate : 0.000006 loss : 0.472685 +[21:07:27.638] iteration 87200 [1554.99 sec]: learning rate : 0.000006 loss : 0.602367 +[21:08:54.199] iteration 87300 [1641.56 sec]: learning rate : 0.000006 loss : 0.383232 +[21:10:20.757] iteration 87400 [1728.11 sec]: learning rate : 0.000006 loss : 0.650823 +[21:11:35.117] Epoch 41 Evaluation: +[21:12:24.836] average MSE: 0.03963155671954155 average PSNR: 29.725902638123 average SSIM: 0.7351198807336559 +[21:12:37.191] iteration 87500 [12.29 sec]: learning rate : 0.000006 loss : 0.679262 +[21:14:03.701] iteration 87600 [98.80 sec]: learning rate : 0.000006 loss : 0.679357 +[21:15:30.169] iteration 87700 [185.27 sec]: learning rate : 0.000006 loss : 0.326561 +[21:16:56.598] iteration 87800 [271.70 sec]: learning rate : 0.000006 loss : 0.379614 +[21:18:23.071] iteration 87900 [358.17 sec]: learning rate : 0.000006 loss : 0.250481 +[21:19:49.576] iteration 88000 [444.68 sec]: learning rate : 0.000006 loss : 0.858552 +[21:21:15.997] iteration 88100 [531.10 sec]: learning rate : 0.000006 loss : 0.305947 +[21:22:42.500] 
iteration 88200 [617.60 sec]: learning rate : 0.000006 loss : 0.364550 +[21:24:08.926] iteration 88300 [704.03 sec]: learning rate : 0.000006 loss : 0.403696 +[21:25:35.419] iteration 88400 [790.52 sec]: learning rate : 0.000006 loss : 0.291269 +[21:27:01.936] iteration 88500 [877.04 sec]: learning rate : 0.000006 loss : 0.442157 +[21:28:28.391] iteration 88600 [963.49 sec]: learning rate : 0.000006 loss : 0.781586 +[21:29:54.880] iteration 88700 [1049.98 sec]: learning rate : 0.000006 loss : 0.671780 +[21:31:21.342] iteration 88800 [1136.44 sec]: learning rate : 0.000006 loss : 0.486177 +[21:32:47.801] iteration 88900 [1222.90 sec]: learning rate : 0.000006 loss : 0.439288 +[21:34:14.317] iteration 89000 [1309.42 sec]: learning rate : 0.000006 loss : 0.452434 +[21:35:40.794] iteration 89100 [1395.89 sec]: learning rate : 0.000006 loss : 0.466919 +[21:37:07.328] iteration 89200 [1482.43 sec]: learning rate : 0.000006 loss : 0.276566 +[21:38:33.841] iteration 89300 [1568.94 sec]: learning rate : 0.000006 loss : 0.421128 +[21:40:00.289] iteration 89400 [1655.39 sec]: learning rate : 0.000006 loss : 0.377420 +[21:41:26.779] iteration 89500 [1741.88 sec]: learning rate : 0.000006 loss : 0.641131 +[21:42:26.431] Epoch 42 Evaluation: +[21:43:18.165] average MSE: 0.03961680456995964 average PSNR: 29.723651688934495 average SSIM: 0.7344615443160504 +[21:43:45.360] iteration 89600 [27.13 sec]: learning rate : 0.000006 loss : 0.502874 +[21:45:11.848] iteration 89700 [113.62 sec]: learning rate : 0.000006 loss : 0.515168 +[21:46:38.428] iteration 89800 [200.20 sec]: learning rate : 0.000006 loss : 0.470391 +[21:48:04.991] iteration 89900 [286.76 sec]: learning rate : 0.000006 loss : 0.638239 +[21:49:31.498] iteration 90000 [373.27 sec]: learning rate : 0.000006 loss : 0.792757 +[21:50:58.042] iteration 90100 [459.81 sec]: learning rate : 0.000006 loss : 0.349580 +[21:52:24.547] iteration 90200 [546.32 sec]: learning rate : 0.000006 loss : 0.440391 +[21:53:51.058] iteration 90300 [632.83 sec]: learning rate : 0.000006 loss : 0.475181 +[21:55:17.602] iteration 90400 [719.38 sec]: learning rate : 0.000006 loss : 0.379286 +[21:56:44.106] iteration 90500 [805.88 sec]: learning rate : 0.000006 loss : 0.401873 +[21:58:10.734] iteration 90600 [892.51 sec]: learning rate : 0.000006 loss : 0.557009 +[21:59:37.325] iteration 90700 [979.10 sec]: learning rate : 0.000006 loss : 0.675533 +[22:01:03.852] iteration 90800 [1065.62 sec]: learning rate : 0.000006 loss : 0.538489 +[22:02:30.454] iteration 90900 [1152.23 sec]: learning rate : 0.000006 loss : 0.406627 +[22:03:57.066] iteration 91000 [1238.84 sec]: learning rate : 0.000006 loss : 0.336532 +[22:05:23.586] iteration 91100 [1325.36 sec]: learning rate : 0.000006 loss : 0.313466 +[22:06:50.185] iteration 91200 [1411.96 sec]: learning rate : 0.000006 loss : 0.597685 +[22:08:16.724] iteration 91300 [1498.50 sec]: learning rate : 0.000006 loss : 0.463619 +[22:09:43.302] iteration 91400 [1585.07 sec]: learning rate : 0.000006 loss : 0.562245 +[22:11:09.883] iteration 91500 [1671.66 sec]: learning rate : 0.000006 loss : 0.566149 +[22:12:36.447] iteration 91600 [1758.22 sec]: learning rate : 0.000006 loss : 0.936456 +[22:13:21.476] Epoch 43 Evaluation: +[22:14:11.583] average MSE: 0.039545997977256775 average PSNR: 29.72837672538749 average SSIM: 0.7344399688910702 +[22:14:53.352] iteration 91700 [41.71 sec]: learning rate : 0.000006 loss : 0.869373 +[22:16:19.950] iteration 91800 [128.30 sec]: learning rate : 0.000006 loss : 0.573117 +[22:17:46.444] iteration 91900 
[214.80 sec]: learning rate : 0.000006 loss : 0.389945 +[22:19:12.976] iteration 92000 [301.33 sec]: learning rate : 0.000006 loss : 0.464698 +[22:20:39.521] iteration 92100 [387.87 sec]: learning rate : 0.000006 loss : 0.656250 +[22:22:06.007] iteration 92200 [474.36 sec]: learning rate : 0.000006 loss : 0.457008 +[22:23:32.556] iteration 92300 [560.91 sec]: learning rate : 0.000006 loss : 0.532187 +[22:24:59.103] iteration 92400 [647.46 sec]: learning rate : 0.000006 loss : 0.530548 +[22:26:25.588] iteration 92500 [733.94 sec]: learning rate : 0.000006 loss : 0.404542 +[22:27:52.159] iteration 92600 [820.51 sec]: learning rate : 0.000006 loss : 0.489227 +[22:29:18.687] iteration 92700 [907.04 sec]: learning rate : 0.000006 loss : 0.649819 +[22:30:45.213] iteration 92800 [993.57 sec]: learning rate : 0.000006 loss : 0.476731 +[22:32:11.793] iteration 92900 [1080.15 sec]: learning rate : 0.000006 loss : 0.323209 +[22:33:38.274] iteration 93000 [1166.63 sec]: learning rate : 0.000006 loss : 0.312852 +[22:35:04.831] iteration 93100 [1253.18 sec]: learning rate : 0.000006 loss : 0.728102 +[22:36:31.341] iteration 93200 [1339.69 sec]: learning rate : 0.000006 loss : 0.554720 +[22:37:57.842] iteration 93300 [1426.20 sec]: learning rate : 0.000006 loss : 0.407933 +[22:39:24.413] iteration 93400 [1512.77 sec]: learning rate : 0.000006 loss : 0.388054 +[22:40:50.950] iteration 93500 [1599.30 sec]: learning rate : 0.000006 loss : 0.479409 +[22:42:17.438] iteration 93600 [1685.79 sec]: learning rate : 0.000006 loss : 0.430466 +[22:43:43.983] iteration 93700 [1772.34 sec]: learning rate : 0.000006 loss : 0.421703 +[22:44:14.228] Epoch 44 Evaluation: +[22:45:04.312] average MSE: 0.039344917982816696 average PSNR: 29.750822057744756 average SSIM: 0.7340749400020316 +[22:46:00.758] iteration 93800 [56.38 sec]: learning rate : 0.000006 loss : 0.547442 +[22:47:27.351] iteration 93900 [142.98 sec]: learning rate : 0.000006 loss : 0.339520 +[22:48:53.890] iteration 94000 [229.52 sec]: learning rate : 0.000006 loss : 0.673648 +[22:50:20.346] iteration 94100 [315.98 sec]: learning rate : 0.000006 loss : 0.463482 +[22:51:46.875] iteration 94200 [402.50 sec]: learning rate : 0.000006 loss : 0.480645 +[22:53:13.391] iteration 94300 [489.02 sec]: learning rate : 0.000006 loss : 0.321315 +[22:54:39.862] iteration 94400 [575.49 sec]: learning rate : 0.000006 loss : 0.648326 +[22:56:06.397] iteration 94500 [662.02 sec]: learning rate : 0.000006 loss : 0.946629 +[22:57:32.867] iteration 94600 [748.49 sec]: learning rate : 0.000006 loss : 0.444372 +[22:58:59.413] iteration 94700 [835.04 sec]: learning rate : 0.000006 loss : 0.843241 +[23:00:25.941] iteration 94800 [921.57 sec]: learning rate : 0.000006 loss : 0.745952 +[23:01:52.415] iteration 94900 [1008.04 sec]: learning rate : 0.000006 loss : 0.370592 +[23:03:18.946] iteration 95000 [1094.57 sec]: learning rate : 0.000006 loss : 0.401136 +[23:04:45.470] iteration 95100 [1181.10 sec]: learning rate : 0.000006 loss : 0.683997 +[23:06:11.986] iteration 95200 [1267.61 sec]: learning rate : 0.000006 loss : 0.294999 +[23:07:38.532] iteration 95300 [1354.16 sec]: learning rate : 0.000006 loss : 0.348935 +[23:09:05.108] iteration 95400 [1440.74 sec]: learning rate : 0.000006 loss : 0.935463 +[23:10:31.652] iteration 95500 [1527.28 sec]: learning rate : 0.000006 loss : 0.519039 +[23:11:58.224] iteration 95600 [1613.85 sec]: learning rate : 0.000006 loss : 0.478777 +[23:13:24.734] iteration 95700 [1700.36 sec]: learning rate : 0.000006 loss : 0.473422 +[23:14:51.279] 
iteration 95800 [1786.90 sec]: learning rate : 0.000006 loss : 0.391426 +[23:15:06.815] Epoch 45 Evaluation: +[23:15:56.072] average MSE: 0.039515476673841476 average PSNR: 29.73513203902022 average SSIM: 0.734285819731764 +[23:17:07.360] iteration 95900 [71.22 sec]: learning rate : 0.000006 loss : 0.531598 +[23:18:33.913] iteration 96000 [157.78 sec]: learning rate : 0.000006 loss : 0.435336 +[23:20:00.526] iteration 96100 [244.39 sec]: learning rate : 0.000006 loss : 0.482073 +[23:21:27.136] iteration 96200 [331.00 sec]: learning rate : 0.000006 loss : 0.363560 +[23:22:53.665] iteration 96300 [417.53 sec]: learning rate : 0.000006 loss : 0.898153 +[23:24:20.298] iteration 96400 [504.16 sec]: learning rate : 0.000006 loss : 0.441027 +[23:25:46.939] iteration 96500 [590.80 sec]: learning rate : 0.000006 loss : 0.480608 +[23:27:13.499] iteration 96600 [677.36 sec]: learning rate : 0.000006 loss : 0.354588 +[23:28:40.137] iteration 96700 [764.00 sec]: learning rate : 0.000006 loss : 0.343510 +[23:30:06.694] iteration 96800 [850.56 sec]: learning rate : 0.000006 loss : 0.430002 +[23:31:33.312] iteration 96900 [937.18 sec]: learning rate : 0.000006 loss : 0.946633 +[23:32:59.919] iteration 97000 [1023.78 sec]: learning rate : 0.000006 loss : 0.462090 +[23:34:26.477] iteration 97100 [1110.34 sec]: learning rate : 0.000006 loss : 0.616307 +[23:35:53.092] iteration 97200 [1196.96 sec]: learning rate : 0.000006 loss : 0.489667 +[23:37:19.710] iteration 97300 [1283.57 sec]: learning rate : 0.000006 loss : 0.583260 +[23:38:46.273] iteration 97400 [1370.14 sec]: learning rate : 0.000006 loss : 0.791832 +[23:40:12.931] iteration 97500 [1456.80 sec]: learning rate : 0.000006 loss : 0.415731 +[23:41:39.537] iteration 97600 [1543.40 sec]: learning rate : 0.000006 loss : 0.271638 +[23:43:06.102] iteration 97700 [1629.97 sec]: learning rate : 0.000006 loss : 0.402323 +[23:44:32.717] iteration 97800 [1716.58 sec]: learning rate : 0.000006 loss : 0.181545 +[23:45:59.356] iteration 97900 [1803.22 sec]: learning rate : 0.000006 loss : 0.290851 +[23:46:00.208] Epoch 46 Evaluation: +[23:46:52.142] average MSE: 0.039540331810712814 average PSNR: 29.735238433119928 average SSIM: 0.734860555301494 +[23:48:17.982] iteration 98000 [85.78 sec]: learning rate : 0.000006 loss : 0.371397 +[23:49:44.575] iteration 98100 [172.37 sec]: learning rate : 0.000006 loss : 0.482359 +[23:51:11.084] iteration 98200 [258.88 sec]: learning rate : 0.000006 loss : 0.423288 +[23:52:37.652] iteration 98300 [345.45 sec]: learning rate : 0.000006 loss : 0.384217 +[23:54:04.230] iteration 98400 [432.02 sec]: learning rate : 0.000006 loss : 0.407239 +[23:55:30.699] iteration 98500 [518.49 sec]: learning rate : 0.000006 loss : 0.558223 +[23:56:57.261] iteration 98600 [605.06 sec]: learning rate : 0.000006 loss : 0.451852 +[23:58:23.802] iteration 98700 [691.62 sec]: learning rate : 0.000006 loss : 0.370123 +[23:59:50.285] iteration 98800 [778.08 sec]: learning rate : 0.000006 loss : 0.732954 +[00:01:16.771] iteration 98900 [864.57 sec]: learning rate : 0.000006 loss : 0.326492 +[00:02:43.272] iteration 99000 [951.07 sec]: learning rate : 0.000006 loss : 0.664928 +[00:04:09.806] iteration 99100 [1037.60 sec]: learning rate : 0.000006 loss : 0.456237 +[00:05:36.392] iteration 99200 [1124.19 sec]: learning rate : 0.000006 loss : 0.279702 +[00:07:02.891] iteration 99300 [1210.69 sec]: learning rate : 0.000006 loss : 0.536575 +[00:08:29.436] iteration 99400 [1297.23 sec]: learning rate : 0.000006 loss : 1.016204 +[00:09:55.970] iteration 99500 
[1383.76 sec]: learning rate : 0.000006 loss : 0.281312 +[00:11:22.455] iteration 99600 [1470.25 sec]: learning rate : 0.000006 loss : 0.353236 +[00:12:48.991] iteration 99700 [1556.79 sec]: learning rate : 0.000006 loss : 0.478901 +[00:14:15.547] iteration 99800 [1643.34 sec]: learning rate : 0.000006 loss : 0.406837 +[00:15:42.059] iteration 99900 [1729.85 sec]: learning rate : 0.000006 loss : 0.618581 +[00:16:54.750] Epoch 47 Evaluation: +[00:17:44.650] average MSE: 0.039397966116666794 average PSNR: 29.748059872230638 average SSIM: 0.7343177661879344 +[00:17:58.710] iteration 100000 [14.00 sec]: learning rate : 0.000002 loss : 0.603931 +[00:17:58.904] save model to model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/iter_100000.pth +[00:17:59.751] Epoch 48 Evaluation: +[00:18:50.256] average MSE: 0.03947655111551285 average PSNR: 29.73539868306476 average SSIM: 0.7346509420864041 +[00:18:50.527] save model to model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/iter_100000.pth diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/log/events.out.tfevents.1752647839.GCRSANDBOX133 b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/log/events.out.tfevents.1752647839.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..37c2eac0bc06ec4a644786844d3bb9888d6c8977 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time_no_distortion/log/events.out.tfevents.1752647839.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f7162d75cc2ae4b6790f1a8ed1cd4962555c9784071f12d2593e17f3fbbad4c +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..4e46df8452259929ee2d695124ac9d74215bc236 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cf872bad2372173060af4fa7722c0cf18b03c829b56ec705bf447cacc3396ac +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..4cb408e8585382ffb14cf50f39bbbb62f2126c3a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log.txt @@ -0,0 +1,1105 @@ +[23:35:01.329] Namespace(root_path='/home/v-qichen3/MRI_recon/data/fastmri', MRIDOWN='8X', low_field_SNR=0, phase='train', gpu='1', exp='FSMNet_fastmri_8x', max_iterations=100000, batch_size=4, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, snapshot_path='None', rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='equispaced', CENTER_FRACTIONS=[0.04], ACCELERATIONS=[8], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, 
test_tag=None, test_sample='Ksample') +[23:36:31.178] iteration 100 [87.20 sec]: learning rate : 0.000100 loss : 0.858667 +[23:37:57.302] iteration 200 [173.32 sec]: learning rate : 0.000100 loss : 0.927446 +[23:39:23.435] iteration 300 [259.45 sec]: learning rate : 0.000100 loss : 0.558551 +[23:40:49.686] iteration 400 [345.71 sec]: learning rate : 0.000100 loss : 0.609569 +[23:42:16.015] iteration 500 [432.03 sec]: learning rate : 0.000100 loss : 0.662211 +[23:43:42.475] iteration 600 [518.50 sec]: learning rate : 0.000100 loss : 0.619656 +[23:45:09.035] iteration 700 [605.05 sec]: learning rate : 0.000100 loss : 1.008130 +[23:46:35.585] iteration 800 [691.60 sec]: learning rate : 0.000100 loss : 0.640344 +[23:48:02.208] iteration 900 [778.23 sec]: learning rate : 0.000100 loss : 0.962376 +[23:49:28.860] iteration 1000 [864.88 sec]: learning rate : 0.000100 loss : 1.721478 +[23:50:55.557] iteration 1100 [951.58 sec]: learning rate : 0.000100 loss : 0.646592 +[23:52:22.219] iteration 1200 [1038.24 sec]: learning rate : 0.000100 loss : 0.948289 +[23:53:48.867] iteration 1300 [1124.89 sec]: learning rate : 0.000100 loss : 0.602626 +[23:55:15.601] iteration 1400 [1211.62 sec]: learning rate : 0.000100 loss : 0.845809 +[23:56:42.343] iteration 1500 [1298.36 sec]: learning rate : 0.000100 loss : 0.911223 +[23:58:08.992] iteration 1600 [1385.01 sec]: learning rate : 0.000100 loss : 0.967551 +[23:59:35.746] iteration 1700 [1471.76 sec]: learning rate : 0.000100 loss : 0.680598 +[00:01:02.500] iteration 1800 [1558.52 sec]: learning rate : 0.000100 loss : 0.689657 +[00:02:29.215] iteration 1900 [1645.23 sec]: learning rate : 0.000100 loss : 0.651348 +[00:03:56.009] iteration 2000 [1732.03 sec]: learning rate : 0.000100 loss : 0.560522 +[00:05:08.044] Epoch 0 Evaluation: +[00:05:58.438] average MSE: 0.06918524950742722 average PSNR: 26.926520795563402 average SSIM: 0.5449168961517686 +[00:06:13.443] iteration 2100 [14.94 sec]: learning rate : 0.000100 loss : 0.560556 +[00:07:40.271] iteration 2200 [101.77 sec]: learning rate : 0.000100 loss : 0.841981 +[00:09:07.054] iteration 2300 [188.55 sec]: learning rate : 0.000100 loss : 1.303490 +[00:10:33.789] iteration 2400 [275.29 sec]: learning rate : 0.000100 loss : 0.843573 +[00:12:00.633] iteration 2500 [362.13 sec]: learning rate : 0.000100 loss : 0.643513 +[00:13:27.470] iteration 2600 [448.97 sec]: learning rate : 0.000100 loss : 1.028987 +[00:14:54.235] iteration 2700 [535.73 sec]: learning rate : 0.000100 loss : 0.940151 +[00:16:21.055] iteration 2800 [622.55 sec]: learning rate : 0.000100 loss : 0.679206 +[00:17:47.880] iteration 2900 [709.38 sec]: learning rate : 0.000100 loss : 0.875939 +[00:19:14.659] iteration 3000 [796.16 sec]: learning rate : 0.000100 loss : 0.845913 +[00:20:41.509] iteration 3100 [883.01 sec]: learning rate : 0.000100 loss : 1.109597 +[00:22:08.296] iteration 3200 [969.79 sec]: learning rate : 0.000100 loss : 0.510848 +[00:23:35.067] iteration 3300 [1056.57 sec]: learning rate : 0.000100 loss : 0.898591 +[00:25:01.921] iteration 3400 [1143.42 sec]: learning rate : 0.000100 loss : 0.627269 +[00:26:28.691] iteration 3500 [1230.19 sec]: learning rate : 0.000100 loss : 0.793013 +[00:27:55.480] iteration 3600 [1316.98 sec]: learning rate : 0.000100 loss : 0.614892 +[00:29:22.240] iteration 3700 [1403.74 sec]: learning rate : 0.000100 loss : 0.611904 +[00:30:49.077] iteration 3800 [1490.58 sec]: learning rate : 0.000100 loss : 0.624907 +[00:32:15.858] iteration 3900 [1577.36 sec]: learning rate : 0.000100 loss : 0.653131 
+[00:33:42.605] iteration 4000 [1664.10 sec]: learning rate : 0.000100 loss : 0.587875 +[00:35:09.421] iteration 4100 [1750.92 sec]: learning rate : 0.000100 loss : 0.578590 +[00:36:06.654] Epoch 1 Evaluation: +[00:36:57.130] average MSE: 0.0645923912525177 average PSNR: 27.26899708353102 average SSIM: 0.5674550888687998 +[00:37:26.880] iteration 4200 [29.69 sec]: learning rate : 0.000100 loss : 1.218155 +[00:38:53.702] iteration 4300 [116.51 sec]: learning rate : 0.000100 loss : 0.647569 +[00:40:20.478] iteration 4400 [203.28 sec]: learning rate : 0.000100 loss : 0.667184 +[00:41:47.193] iteration 4500 [290.00 sec]: learning rate : 0.000100 loss : 0.724785 +[00:43:14.018] iteration 4600 [376.83 sec]: learning rate : 0.000100 loss : 0.789799 +[00:44:40.780] iteration 4700 [463.59 sec]: learning rate : 0.000100 loss : 0.495913 +[00:46:07.528] iteration 4800 [550.33 sec]: learning rate : 0.000100 loss : 0.655529 +[00:47:34.289] iteration 4900 [637.10 sec]: learning rate : 0.000100 loss : 0.781471 +[00:49:01.047] iteration 5000 [723.85 sec]: learning rate : 0.000100 loss : 1.041412 +[00:50:27.755] iteration 5100 [810.56 sec]: learning rate : 0.000100 loss : 0.696230 +[00:51:54.492] iteration 5200 [897.30 sec]: learning rate : 0.000100 loss : 0.566087 +[00:53:21.152] iteration 5300 [983.96 sec]: learning rate : 0.000100 loss : 0.332724 +[00:54:47.869] iteration 5400 [1070.68 sec]: learning rate : 0.000100 loss : 0.506632 +[00:56:14.575] iteration 5500 [1157.38 sec]: learning rate : 0.000100 loss : 0.752759 +[00:57:41.236] iteration 5600 [1244.04 sec]: learning rate : 0.000100 loss : 0.425816 +[00:59:07.924] iteration 5700 [1330.73 sec]: learning rate : 0.000100 loss : 0.698463 +[01:00:34.571] iteration 5800 [1417.38 sec]: learning rate : 0.000100 loss : 0.446285 +[01:02:01.278] iteration 5900 [1504.08 sec]: learning rate : 0.000100 loss : 0.807647 +[01:03:27.985] iteration 6000 [1590.79 sec]: learning rate : 0.000100 loss : 0.577452 +[01:04:54.634] iteration 6100 [1677.44 sec]: learning rate : 0.000100 loss : 0.411280 +[01:06:21.317] iteration 6200 [1764.12 sec]: learning rate : 0.000100 loss : 0.594680 +[01:07:03.750] Epoch 2 Evaluation: +[01:07:53.961] average MSE: 0.06299909204244614 average PSNR: 27.395351055082518 average SSIM: 0.568226189659116 +[01:08:38.468] iteration 6300 [44.44 sec]: learning rate : 0.000100 loss : 0.939943 +[01:10:05.032] iteration 6400 [131.01 sec]: learning rate : 0.000100 loss : 0.605865 +[01:11:31.683] iteration 6500 [217.66 sec]: learning rate : 0.000100 loss : 0.562848 +[01:12:58.364] iteration 6600 [304.34 sec]: learning rate : 0.000100 loss : 0.625156 +[01:14:24.942] iteration 6700 [390.92 sec]: learning rate : 0.000100 loss : 0.497251 +[01:15:51.601] iteration 6800 [477.58 sec]: learning rate : 0.000100 loss : 0.407128 +[01:17:18.259] iteration 6900 [564.24 sec]: learning rate : 0.000100 loss : 0.743780 +[01:18:44.857] iteration 7000 [650.83 sec]: learning rate : 0.000100 loss : 0.902790 +[01:20:11.518] iteration 7100 [737.49 sec]: learning rate : 0.000100 loss : 0.752017 +[01:21:38.153] iteration 7200 [824.13 sec]: learning rate : 0.000100 loss : 0.532995 +[01:23:04.718] iteration 7300 [910.69 sec]: learning rate : 0.000100 loss : 0.347972 +[01:24:31.305] iteration 7400 [997.28 sec]: learning rate : 0.000100 loss : 0.704105 +[01:25:57.872] iteration 7500 [1083.85 sec]: learning rate : 0.000100 loss : 0.358001 +[01:27:24.506] iteration 7600 [1170.48 sec]: learning rate : 0.000100 loss : 0.587167 +[01:28:51.080] iteration 7700 [1257.06 sec]: learning rate : 
0.000100 loss : 0.741259 +[01:30:17.674] iteration 7800 [1343.65 sec]: learning rate : 0.000100 loss : 0.678814 +[01:31:44.289] iteration 7900 [1430.26 sec]: learning rate : 0.000100 loss : 0.757982 +[01:33:10.884] iteration 8000 [1516.86 sec]: learning rate : 0.000100 loss : 0.921447 +[01:34:37.510] iteration 8100 [1603.49 sec]: learning rate : 0.000100 loss : 0.823206 +[01:36:04.161] iteration 8200 [1690.14 sec]: learning rate : 0.000100 loss : 0.505160 +[01:37:30.737] iteration 8300 [1776.71 sec]: learning rate : 0.000100 loss : 0.757268 +[01:37:58.407] Epoch 3 Evaluation: +[01:38:51.008] average MSE: 0.060989197343587875 average PSNR: 27.578989579134344 average SSIM: 0.57620086621237 +[01:39:50.198] iteration 8400 [59.13 sec]: learning rate : 0.000100 loss : 0.532949 +[01:41:16.846] iteration 8500 [145.77 sec]: learning rate : 0.000100 loss : 0.562594 +[01:42:43.360] iteration 8600 [232.29 sec]: learning rate : 0.000100 loss : 0.893589 +[01:44:09.937] iteration 8700 [318.87 sec]: learning rate : 0.000100 loss : 0.564883 +[01:45:36.478] iteration 8800 [405.41 sec]: learning rate : 0.000100 loss : 0.678564 +[01:47:03.079] iteration 8900 [492.01 sec]: learning rate : 0.000100 loss : 0.719227 +[01:48:29.670] iteration 9000 [578.60 sec]: learning rate : 0.000100 loss : 0.616028 +[01:49:56.211] iteration 9100 [665.14 sec]: learning rate : 0.000100 loss : 0.466718 +[01:51:22.830] iteration 9200 [751.76 sec]: learning rate : 0.000100 loss : 0.577804 +[01:52:49.372] iteration 9300 [838.30 sec]: learning rate : 0.000100 loss : 0.446547 +[01:54:15.979] iteration 9400 [924.91 sec]: learning rate : 0.000100 loss : 0.608415 +[01:55:42.573] iteration 9500 [1011.50 sec]: learning rate : 0.000100 loss : 0.430119 +[01:57:09.130] iteration 9600 [1098.06 sec]: learning rate : 0.000100 loss : 0.695462 +[01:58:35.727] iteration 9700 [1184.66 sec]: learning rate : 0.000100 loss : 0.644282 +[02:00:02.354] iteration 9800 [1271.28 sec]: learning rate : 0.000100 loss : 0.504093 +[02:01:28.916] iteration 9900 [1357.85 sec]: learning rate : 0.000100 loss : 0.847476 +[02:02:55.539] iteration 10000 [1444.47 sec]: learning rate : 0.000100 loss : 0.562864 +[02:04:22.093] iteration 10100 [1531.02 sec]: learning rate : 0.000100 loss : 0.481930 +[02:05:48.720] iteration 10200 [1617.65 sec]: learning rate : 0.000100 loss : 0.432065 +[02:07:15.326] iteration 10300 [1704.26 sec]: learning rate : 0.000100 loss : 0.388059 +[02:08:41.870] iteration 10400 [1790.80 sec]: learning rate : 0.000100 loss : 0.670896 +[02:08:54.834] Epoch 4 Evaluation: +[02:09:45.134] average MSE: 0.06072581186890602 average PSNR: 27.61362641551422 average SSIM: 0.5799725985447438 +[02:10:59.050] iteration 10500 [73.85 sec]: learning rate : 0.000100 loss : 0.971102 +[02:12:25.653] iteration 10600 [160.46 sec]: learning rate : 0.000100 loss : 0.606781 +[02:13:52.182] iteration 10700 [246.99 sec]: learning rate : 0.000100 loss : 0.938817 +[02:15:18.819] iteration 10800 [333.62 sec]: learning rate : 0.000100 loss : 0.396632 +[02:16:45.377] iteration 10900 [420.18 sec]: learning rate : 0.000100 loss : 0.744601 +[02:18:11.985] iteration 11000 [506.79 sec]: learning rate : 0.000100 loss : 0.689792 +[02:19:38.581] iteration 11100 [593.38 sec]: learning rate : 0.000100 loss : 0.821196 +[02:21:05.157] iteration 11200 [679.96 sec]: learning rate : 0.000100 loss : 0.603405 +[02:22:31.816] iteration 11300 [766.62 sec]: learning rate : 0.000100 loss : 0.974161 +[02:23:58.401] iteration 11400 [853.20 sec]: learning rate : 0.000100 loss : 0.640970 +[02:25:25.088] 
iteration 11500 [939.89 sec]: learning rate : 0.000100 loss : 0.575384 +[02:26:51.780] iteration 11600 [1026.58 sec]: learning rate : 0.000100 loss : 0.660156 +[02:28:18.386] iteration 11700 [1113.19 sec]: learning rate : 0.000100 loss : 1.022221 +[02:29:45.097] iteration 11800 [1199.90 sec]: learning rate : 0.000100 loss : 0.433081 +[02:31:11.778] iteration 11900 [1286.58 sec]: learning rate : 0.000100 loss : 0.582857 +[02:32:38.404] iteration 12000 [1373.21 sec]: learning rate : 0.000100 loss : 0.532787 +[02:34:05.084] iteration 12100 [1459.89 sec]: learning rate : 0.000100 loss : 0.625810 +[02:35:31.716] iteration 12200 [1546.52 sec]: learning rate : 0.000100 loss : 0.897206 +[02:36:58.427] iteration 12300 [1633.23 sec]: learning rate : 0.000100 loss : 0.644684 +[02:38:25.140] iteration 12400 [1719.94 sec]: learning rate : 0.000100 loss : 0.643566 +[02:39:50.024] Epoch 5 Evaluation: +[02:40:40.348] average MSE: 0.06046953797340393 average PSNR: 27.640373957573683 average SSIM: 0.5819115791519055 +[02:40:42.326] iteration 12500 [1.91 sec]: learning rate : 0.000100 loss : 0.598356 +[02:42:09.043] iteration 12600 [88.63 sec]: learning rate : 0.000100 loss : 0.551762 +[02:43:35.706] iteration 12700 [175.30 sec]: learning rate : 0.000100 loss : 0.420289 +[02:45:02.352] iteration 12800 [261.94 sec]: learning rate : 0.000100 loss : 0.618113 +[02:46:29.049] iteration 12900 [348.64 sec]: learning rate : 0.000100 loss : 0.541306 +[02:47:55.712] iteration 13000 [435.30 sec]: learning rate : 0.000100 loss : 0.573243 +[02:49:22.440] iteration 13100 [522.03 sec]: learning rate : 0.000100 loss : 0.730630 +[02:50:49.174] iteration 13200 [608.76 sec]: learning rate : 0.000100 loss : 0.478985 +[02:52:15.851] iteration 13300 [695.44 sec]: learning rate : 0.000100 loss : 0.867212 +[02:53:42.559] iteration 13400 [782.15 sec]: learning rate : 0.000100 loss : 0.824808 +[02:55:09.309] iteration 13500 [868.90 sec]: learning rate : 0.000100 loss : 0.778700 +[02:56:36.015] iteration 13600 [955.61 sec]: learning rate : 0.000100 loss : 0.706907 +[02:58:02.778] iteration 13700 [1042.37 sec]: learning rate : 0.000100 loss : 0.380478 +[02:59:29.533] iteration 13800 [1129.12 sec]: learning rate : 0.000100 loss : 0.471252 +[03:00:56.269] iteration 13900 [1215.86 sec]: learning rate : 0.000100 loss : 0.500590 +[03:02:23.081] iteration 14000 [1302.67 sec]: learning rate : 0.000100 loss : 0.961400 +[03:03:49.840] iteration 14100 [1389.43 sec]: learning rate : 0.000100 loss : 0.295357 +[03:05:16.589] iteration 14200 [1476.18 sec]: learning rate : 0.000100 loss : 0.491226 +[03:06:43.388] iteration 14300 [1562.98 sec]: learning rate : 0.000100 loss : 0.734794 +[03:08:10.170] iteration 14400 [1649.76 sec]: learning rate : 0.000100 loss : 0.660113 +[03:09:36.945] iteration 14500 [1736.53 sec]: learning rate : 0.000100 loss : 0.502851 +[03:10:47.305] Epoch 6 Evaluation: +[03:11:37.799] average MSE: 0.05951939895749092 average PSNR: 27.725434781596196 average SSIM: 0.58324215818431 +[03:11:54.508] iteration 14600 [16.65 sec]: learning rate : 0.000100 loss : 0.686010 +[03:13:21.378] iteration 14700 [103.52 sec]: learning rate : 0.000100 loss : 0.985802 +[03:14:48.124] iteration 14800 [190.26 sec]: learning rate : 0.000100 loss : 0.833239 +[03:16:14.899] iteration 14900 [277.04 sec]: learning rate : 0.000100 loss : 0.937087 +[03:17:41.663] iteration 15000 [363.80 sec]: learning rate : 0.000100 loss : 0.887721 +[03:19:08.480] iteration 15100 [450.62 sec]: learning rate : 0.000100 loss : 0.690313 +[03:20:35.237] iteration 15200 
[537.37 sec]: learning rate : 0.000100 loss : 0.730369 +[03:22:01.994] iteration 15300 [624.13 sec]: learning rate : 0.000100 loss : 0.604261 +[03:23:28.810] iteration 15400 [710.95 sec]: learning rate : 0.000100 loss : 1.303327 +[03:24:55.558] iteration 15500 [797.70 sec]: learning rate : 0.000100 loss : 0.579344 +[03:26:22.362] iteration 15600 [884.50 sec]: learning rate : 0.000100 loss : 0.680411 +[03:27:49.175] iteration 15700 [971.31 sec]: learning rate : 0.000100 loss : 0.708070 +[03:29:15.935] iteration 15800 [1058.07 sec]: learning rate : 0.000100 loss : 0.756683 +[03:30:42.738] iteration 15900 [1144.87 sec]: learning rate : 0.000100 loss : 0.602384 +[03:32:09.569] iteration 16000 [1231.71 sec]: learning rate : 0.000100 loss : 0.776887 +[03:33:36.337] iteration 16100 [1318.47 sec]: learning rate : 0.000100 loss : 0.484428 +[03:35:03.161] iteration 16200 [1405.30 sec]: learning rate : 0.000100 loss : 0.441535 +[03:36:29.916] iteration 16300 [1492.05 sec]: learning rate : 0.000100 loss : 0.518051 +[03:37:56.744] iteration 16400 [1578.88 sec]: learning rate : 0.000100 loss : 0.518025 +[03:39:23.579] iteration 16500 [1665.72 sec]: learning rate : 0.000100 loss : 0.518725 +[03:40:50.332] iteration 16600 [1752.47 sec]: learning rate : 0.000100 loss : 0.662335 +[03:41:45.917] Epoch 7 Evaluation: +[03:42:38.599] average MSE: 0.05894581973552704 average PSNR: 27.783003727613945 average SSIM: 0.5879513254080811 +[03:43:10.076] iteration 16700 [31.42 sec]: learning rate : 0.000100 loss : 0.782939 +[03:44:36.962] iteration 16800 [118.30 sec]: learning rate : 0.000100 loss : 0.558980 +[03:46:03.705] iteration 16900 [205.04 sec]: learning rate : 0.000100 loss : 0.782719 +[03:47:30.533] iteration 17000 [291.87 sec]: learning rate : 0.000100 loss : 0.597183 +[03:48:57.299] iteration 17100 [378.64 sec]: learning rate : 0.000100 loss : 0.839008 +[03:50:24.090] iteration 17200 [465.43 sec]: learning rate : 0.000100 loss : 0.536416 +[03:51:50.916] iteration 17300 [552.25 sec]: learning rate : 0.000100 loss : 0.895383 +[03:53:17.688] iteration 17400 [639.03 sec]: learning rate : 0.000100 loss : 0.778054 +[03:54:44.482] iteration 17500 [725.82 sec]: learning rate : 0.000100 loss : 0.875515 +[03:56:11.294] iteration 17600 [812.63 sec]: learning rate : 0.000100 loss : 0.833459 +[03:57:38.047] iteration 17700 [899.38 sec]: learning rate : 0.000100 loss : 0.659315 +[03:59:04.851] iteration 17800 [986.19 sec]: learning rate : 0.000100 loss : 0.827427 +[04:00:31.681] iteration 17900 [1073.02 sec]: learning rate : 0.000100 loss : 0.448824 +[04:01:58.421] iteration 18000 [1159.76 sec]: learning rate : 0.000100 loss : 0.595526 +[04:03:25.269] iteration 18100 [1246.61 sec]: learning rate : 0.000100 loss : 0.425040 +[04:04:52.042] iteration 18200 [1333.38 sec]: learning rate : 0.000100 loss : 0.458980 +[04:06:18.867] iteration 18300 [1420.20 sec]: learning rate : 0.000100 loss : 0.652044 +[04:07:45.648] iteration 18400 [1506.99 sec]: learning rate : 0.000100 loss : 0.982686 +[04:09:12.397] iteration 18500 [1593.73 sec]: learning rate : 0.000100 loss : 0.590132 +[04:10:39.184] iteration 18600 [1680.52 sec]: learning rate : 0.000100 loss : 0.543913 +[04:12:05.968] iteration 18700 [1767.31 sec]: learning rate : 0.000100 loss : 0.776664 +[04:12:46.703] Epoch 8 Evaluation: +[04:13:37.884] average MSE: 0.0577298104763031 average PSNR: 27.914652241920137 average SSIM: 0.5911181278037124 +[04:14:24.113] iteration 18800 [46.17 sec]: learning rate : 0.000100 loss : 0.491866 +[04:15:50.926] iteration 18900 [132.98 sec]: 
learning rate : 0.000100 loss : 0.629464 +[04:17:17.732] iteration 19000 [219.79 sec]: learning rate : 0.000100 loss : 0.577699 +[04:18:44.466] iteration 19100 [306.52 sec]: learning rate : 0.000100 loss : 0.494681 +[04:20:11.252] iteration 19200 [393.30 sec]: learning rate : 0.000100 loss : 0.573901 +[04:21:38.022] iteration 19300 [480.08 sec]: learning rate : 0.000100 loss : 0.601889 +[04:23:04.728] iteration 19400 [566.78 sec]: learning rate : 0.000100 loss : 0.693935 +[04:24:31.518] iteration 19500 [653.57 sec]: learning rate : 0.000100 loss : 0.706639 +[04:25:58.224] iteration 19600 [740.28 sec]: learning rate : 0.000100 loss : 0.677227 +[04:27:24.968] iteration 19700 [827.02 sec]: learning rate : 0.000100 loss : 0.727638 +[04:28:51.718] iteration 19800 [913.77 sec]: learning rate : 0.000100 loss : 0.662169 +[04:30:18.404] iteration 19900 [1000.46 sec]: learning rate : 0.000100 loss : 0.516716 +[04:31:45.161] iteration 20000 [1087.21 sec]: learning rate : 0.000025 loss : 0.615844 +[04:31:45.329] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_20000.pth +[04:33:12.103] iteration 20100 [1174.16 sec]: learning rate : 0.000050 loss : 0.738200 +[04:34:38.840] iteration 20200 [1260.89 sec]: learning rate : 0.000050 loss : 0.898630 +[04:36:05.635] iteration 20300 [1347.69 sec]: learning rate : 0.000050 loss : 0.877526 +[04:37:32.411] iteration 20400 [1434.46 sec]: learning rate : 0.000050 loss : 0.547199 +[04:38:59.139] iteration 20500 [1521.19 sec]: learning rate : 0.000050 loss : 0.606744 +[04:40:25.889] iteration 20600 [1607.94 sec]: learning rate : 0.000050 loss : 0.847707 +[04:41:52.662] iteration 20700 [1694.71 sec]: learning rate : 0.000050 loss : 0.428505 +[04:43:19.365] iteration 20800 [1781.42 sec]: learning rate : 0.000050 loss : 0.682662 +[04:43:45.401] Epoch 9 Evaluation: +[04:44:37.029] average MSE: 0.05770213529467583 average PSNR: 27.92126583181447 average SSIM: 0.5929080617028896 +[04:45:37.925] iteration 20900 [60.83 sec]: learning rate : 0.000050 loss : 0.681922 +[04:47:04.710] iteration 21000 [147.62 sec]: learning rate : 0.000050 loss : 0.489073 +[04:48:31.381] iteration 21100 [234.29 sec]: learning rate : 0.000050 loss : 0.692695 +[04:49:58.122] iteration 21200 [321.03 sec]: learning rate : 0.000050 loss : 0.617892 +[04:51:24.797] iteration 21300 [407.71 sec]: learning rate : 0.000050 loss : 0.696361 +[04:52:51.523] iteration 21400 [494.44 sec]: learning rate : 0.000050 loss : 0.633455 +[04:54:18.255] iteration 21500 [581.16 sec]: learning rate : 0.000050 loss : 0.421460 +[04:55:44.935] iteration 21600 [667.84 sec]: learning rate : 0.000050 loss : 0.656731 +[04:57:11.658] iteration 21700 [754.57 sec]: learning rate : 0.000050 loss : 0.421824 +[04:58:38.405] iteration 21800 [841.31 sec]: learning rate : 0.000050 loss : 0.606716 +[05:00:05.052] iteration 21900 [927.96 sec]: learning rate : 0.000050 loss : 0.701760 +[05:01:31.808] iteration 22000 [1014.72 sec]: learning rate : 0.000050 loss : 0.405739 +[05:02:58.521] iteration 22100 [1101.43 sec]: learning rate : 0.000050 loss : 0.579437 +[05:04:25.220] iteration 22200 [1188.13 sec]: learning rate : 0.000050 loss : 0.592612 +[05:05:51.965] iteration 22300 [1274.87 sec]: learning rate : 0.000050 loss : 0.521813 +[05:07:18.658] iteration 22400 [1361.57 sec]: learning rate : 0.000050 loss : 0.349943 +[05:08:45.409] iteration 22500 [1448.32 sec]: learning rate : 0.000050 loss : 0.577279 +[05:10:12.164] iteration 22600 [1535.07 sec]: learning rate : 0.000050 loss : 0.606772 +[05:11:38.825] iteration 22700 [1621.73 
sec]: learning rate : 0.000050 loss : 0.293787 +[05:13:05.548] iteration 22800 [1708.46 sec]: learning rate : 0.000050 loss : 0.514689 +[05:14:32.260] iteration 22900 [1795.17 sec]: learning rate : 0.000050 loss : 0.288859 +[05:14:43.494] Epoch 10 Evaluation: +[05:15:33.613] average MSE: 0.05718550086021423 average PSNR: 27.970714675657167 average SSIM: 0.5939281987650734 +[05:16:49.236] iteration 23000 [75.56 sec]: learning rate : 0.000050 loss : 0.887481 +[05:18:15.984] iteration 23100 [162.31 sec]: learning rate : 0.000050 loss : 0.588766 +[05:19:42.738] iteration 23200 [249.06 sec]: learning rate : 0.000050 loss : 0.786999 +[05:21:09.411] iteration 23300 [335.73 sec]: learning rate : 0.000050 loss : 0.263807 +[05:22:36.115] iteration 23400 [422.44 sec]: learning rate : 0.000050 loss : 0.468895 +[05:24:02.829] iteration 23500 [509.15 sec]: learning rate : 0.000050 loss : 0.667018 +[05:25:29.492] iteration 23600 [595.82 sec]: learning rate : 0.000050 loss : 0.799004 +[05:26:56.174] iteration 23700 [682.50 sec]: learning rate : 0.000050 loss : 0.414058 +[05:28:22.902] iteration 23800 [769.23 sec]: learning rate : 0.000050 loss : 0.720424 +[05:29:49.532] iteration 23900 [855.86 sec]: learning rate : 0.000050 loss : 0.594778 +[05:31:16.201] iteration 24000 [942.52 sec]: learning rate : 0.000050 loss : 0.557455 +[05:32:42.859] iteration 24100 [1029.18 sec]: learning rate : 0.000050 loss : 0.600025 +[05:34:09.596] iteration 24200 [1115.92 sec]: learning rate : 0.000050 loss : 0.357790 +[05:35:36.286] iteration 24300 [1202.61 sec]: learning rate : 0.000050 loss : 0.477838 +[05:37:02.943] iteration 24400 [1289.27 sec]: learning rate : 0.000050 loss : 0.685633 +[05:38:29.682] iteration 24500 [1376.01 sec]: learning rate : 0.000050 loss : 0.698611 +[05:39:56.375] iteration 24600 [1462.70 sec]: learning rate : 0.000050 loss : 1.115312 +[05:41:23.080] iteration 24700 [1549.40 sec]: learning rate : 0.000050 loss : 0.822841 +[05:42:49.788] iteration 24800 [1636.11 sec]: learning rate : 0.000050 loss : 0.807623 +[05:44:16.456] iteration 24900 [1722.78 sec]: learning rate : 0.000050 loss : 0.572549 +[05:45:39.690] Epoch 11 Evaluation: +[05:46:30.235] average MSE: 0.056621160358190536 average PSNR: 28.024328435328894 average SSIM: 0.5959611433251605 +[05:46:33.978] iteration 25000 [3.68 sec]: learning rate : 0.000050 loss : 0.450612 +[05:48:00.618] iteration 25100 [90.32 sec]: learning rate : 0.000050 loss : 0.589377 +[05:49:27.367] iteration 25200 [177.07 sec]: learning rate : 0.000050 loss : 0.601901 +[05:50:54.096] iteration 25300 [263.80 sec]: learning rate : 0.000050 loss : 0.728163 +[05:52:20.755] iteration 25400 [350.46 sec]: learning rate : 0.000050 loss : 1.275169 +[05:53:47.461] iteration 25500 [437.16 sec]: learning rate : 0.000050 loss : 0.351298 +[05:55:14.194] iteration 25600 [523.90 sec]: learning rate : 0.000050 loss : 1.171507 +[05:56:40.856] iteration 25700 [610.56 sec]: learning rate : 0.000050 loss : 0.370142 +[05:58:07.605] iteration 25800 [697.31 sec]: learning rate : 0.000050 loss : 0.936035 +[05:59:34.331] iteration 25900 [784.03 sec]: learning rate : 0.000050 loss : 0.388894 +[06:01:01.012] iteration 26000 [870.72 sec]: learning rate : 0.000050 loss : 0.915216 +[06:02:27.736] iteration 26100 [957.44 sec]: learning rate : 0.000050 loss : 0.747489 +[06:03:54.391] iteration 26200 [1044.09 sec]: learning rate : 0.000050 loss : 0.458830 +[06:05:21.123] iteration 26300 [1130.82 sec]: learning rate : 0.000050 loss : 0.596690 +[06:06:47.856] iteration 26400 [1217.56 sec]: learning rate 
: 0.000050 loss : 0.430359 +[06:08:14.524] iteration 26500 [1304.23 sec]: learning rate : 0.000050 loss : 0.669917 +[06:09:41.266] iteration 26600 [1390.97 sec]: learning rate : 0.000050 loss : 0.540063 +[06:11:08.008] iteration 26700 [1477.71 sec]: learning rate : 0.000050 loss : 0.548021 +[06:12:34.687] iteration 26800 [1564.39 sec]: learning rate : 0.000050 loss : 0.432187 +[06:14:01.440] iteration 26900 [1651.14 sec]: learning rate : 0.000050 loss : 0.609635 +[06:15:28.172] iteration 27000 [1737.87 sec]: learning rate : 0.000050 loss : 0.605597 +[06:16:36.643] Epoch 12 Evaluation: +[06:17:26.634] average MSE: 0.05703386291861534 average PSNR: 27.99292785219541 average SSIM: 0.5950611176847864 +[06:17:45.081] iteration 27100 [18.38 sec]: learning rate : 0.000050 loss : 0.691979 +[06:19:11.855] iteration 27200 [105.16 sec]: learning rate : 0.000050 loss : 0.742221 +[06:20:38.595] iteration 27300 [191.90 sec]: learning rate : 0.000050 loss : 0.506363 +[06:22:05.269] iteration 27400 [278.57 sec]: learning rate : 0.000050 loss : 0.943229 +[06:23:32.041] iteration 27500 [365.34 sec]: learning rate : 0.000050 loss : 0.448066 +[06:24:58.704] iteration 27600 [452.01 sec]: learning rate : 0.000050 loss : 0.627242 +[06:26:25.408] iteration 27700 [538.71 sec]: learning rate : 0.000050 loss : 0.548009 +[06:27:52.133] iteration 27800 [625.44 sec]: learning rate : 0.000050 loss : 0.529415 +[06:29:18.825] iteration 27900 [712.13 sec]: learning rate : 0.000050 loss : 0.346443 +[06:30:45.590] iteration 28000 [798.89 sec]: learning rate : 0.000050 loss : 0.524403 +[06:32:12.347] iteration 28100 [885.65 sec]: learning rate : 0.000050 loss : 0.613265 +[06:33:39.049] iteration 28200 [972.35 sec]: learning rate : 0.000050 loss : 0.788386 +[06:35:05.843] iteration 28300 [1059.15 sec]: learning rate : 0.000050 loss : 0.707291 +[06:36:32.601] iteration 28400 [1145.91 sec]: learning rate : 0.000050 loss : 0.666513 +[06:37:59.317] iteration 28500 [1232.62 sec]: learning rate : 0.000050 loss : 0.553273 +[06:39:26.069] iteration 28600 [1319.37 sec]: learning rate : 0.000050 loss : 0.461110 +[06:40:52.774] iteration 28700 [1406.08 sec]: learning rate : 0.000050 loss : 0.743283 +[06:42:19.535] iteration 28800 [1492.84 sec]: learning rate : 0.000050 loss : 0.489518 +[06:43:46.338] iteration 28900 [1579.64 sec]: learning rate : 0.000050 loss : 0.498388 +[06:45:13.045] iteration 29000 [1666.35 sec]: learning rate : 0.000050 loss : 0.843259 +[06:46:39.821] iteration 29100 [1753.12 sec]: learning rate : 0.000050 loss : 0.587710 +[06:47:33.569] Epoch 13 Evaluation: +[06:48:24.118] average MSE: 0.05641324818134308 average PSNR: 28.046936387970586 average SSIM: 0.5961877034361103 +[06:48:57.437] iteration 29200 [33.26 sec]: learning rate : 0.000050 loss : 0.693869 +[06:50:24.112] iteration 29300 [119.93 sec]: learning rate : 0.000050 loss : 0.571310 +[06:51:50.878] iteration 29400 [206.70 sec]: learning rate : 0.000050 loss : 0.803224 +[06:53:17.591] iteration 29500 [293.41 sec]: learning rate : 0.000050 loss : 0.313115 +[06:54:44.448] iteration 29600 [380.27 sec]: learning rate : 0.000050 loss : 0.446277 +[06:56:11.231] iteration 29700 [467.05 sec]: learning rate : 0.000050 loss : 0.827076 +[06:57:37.978] iteration 29800 [553.80 sec]: learning rate : 0.000050 loss : 0.424837 +[06:59:04.766] iteration 29900 [640.59 sec]: learning rate : 0.000050 loss : 0.323419 +[07:00:31.570] iteration 30000 [727.39 sec]: learning rate : 0.000050 loss : 0.554499 +[07:01:58.286] iteration 30100 [814.11 sec]: learning rate : 0.000050 loss : 
0.632158 +[07:03:25.071] iteration 30200 [900.89 sec]: learning rate : 0.000050 loss : 0.329621 +[07:04:51.881] iteration 30300 [987.71 sec]: learning rate : 0.000050 loss : 0.731229 +[07:06:18.631] iteration 30400 [1074.45 sec]: learning rate : 0.000050 loss : 0.834342 +[07:07:45.395] iteration 30500 [1161.23 sec]: learning rate : 0.000050 loss : 0.872011 +[07:09:12.230] iteration 30600 [1248.05 sec]: learning rate : 0.000050 loss : 0.562032 +[07:10:38.992] iteration 30700 [1334.81 sec]: learning rate : 0.000050 loss : 0.476894 +[07:12:05.788] iteration 30800 [1421.61 sec]: learning rate : 0.000050 loss : 0.525584 +[07:13:32.607] iteration 30900 [1508.43 sec]: learning rate : 0.000050 loss : 0.687764 +[07:14:59.345] iteration 31000 [1595.17 sec]: learning rate : 0.000050 loss : 0.569546 +[07:16:26.096] iteration 31100 [1681.92 sec]: learning rate : 0.000050 loss : 0.411899 +[07:17:52.912] iteration 31200 [1768.73 sec]: learning rate : 0.000050 loss : 0.662882 +[07:18:31.916] Epoch 14 Evaluation: +[07:19:24.210] average MSE: 0.056387584656476974 average PSNR: 28.04292470828196 average SSIM: 0.5962665795910712 +[07:20:12.148] iteration 31300 [47.88 sec]: learning rate : 0.000050 loss : 0.822622 +[07:21:38.981] iteration 31400 [134.71 sec]: learning rate : 0.000050 loss : 0.556343 +[07:23:05.774] iteration 31500 [221.50 sec]: learning rate : 0.000050 loss : 0.599553 +[07:24:32.513] iteration 31600 [308.24 sec]: learning rate : 0.000050 loss : 0.397999 +[07:25:59.292] iteration 31700 [395.02 sec]: learning rate : 0.000050 loss : 0.540428 +[07:27:26.037] iteration 31800 [481.76 sec]: learning rate : 0.000050 loss : 0.742871 +[07:28:52.846] iteration 31900 [568.57 sec]: learning rate : 0.000050 loss : 0.704103 +[07:30:19.664] iteration 32000 [655.39 sec]: learning rate : 0.000050 loss : 0.641131 +[07:31:46.478] iteration 32100 [742.20 sec]: learning rate : 0.000050 loss : 0.522990 +[07:33:13.313] iteration 32200 [829.04 sec]: learning rate : 0.000050 loss : 0.707133 +[07:34:40.072] iteration 32300 [915.80 sec]: learning rate : 0.000050 loss : 0.703620 +[07:36:06.871] iteration 32400 [1002.60 sec]: learning rate : 0.000050 loss : 0.307118 +[07:37:33.698] iteration 32500 [1089.43 sec]: learning rate : 0.000050 loss : 0.623223 +[07:39:00.455] iteration 32600 [1176.18 sec]: learning rate : 0.000050 loss : 0.400727 +[07:40:27.236] iteration 32700 [1262.96 sec]: learning rate : 0.000050 loss : 0.581723 +[07:41:53.994] iteration 32800 [1349.72 sec]: learning rate : 0.000050 loss : 0.579330 +[07:43:20.779] iteration 32900 [1436.51 sec]: learning rate : 0.000050 loss : 0.452466 +[07:44:47.558] iteration 33000 [1523.29 sec]: learning rate : 0.000050 loss : 0.667161 +[07:46:14.332] iteration 33100 [1610.06 sec]: learning rate : 0.000050 loss : 0.579581 +[07:47:41.150] iteration 33200 [1696.88 sec]: learning rate : 0.000050 loss : 0.529430 +[07:49:07.885] iteration 33300 [1783.61 sec]: learning rate : 0.000050 loss : 0.546139 +[07:49:32.199] Epoch 15 Evaluation: +[07:50:23.537] average MSE: 0.05641862750053406 average PSNR: 28.059351683417507 average SSIM: 0.5974735257880057 +[07:51:26.201] iteration 33400 [62.60 sec]: learning rate : 0.000050 loss : 0.510199 +[07:52:53.042] iteration 33500 [149.44 sec]: learning rate : 0.000050 loss : 0.755949 +[07:54:19.802] iteration 33600 [236.20 sec]: learning rate : 0.000050 loss : 0.610239 +[07:55:46.581] iteration 33700 [322.98 sec]: learning rate : 0.000050 loss : 0.613713 +[07:57:13.376] iteration 33800 [409.78 sec]: learning rate : 0.000050 loss : 0.400652 
+[07:58:40.082] iteration 33900 [496.48 sec]: learning rate : 0.000050 loss : 0.773322 +[08:00:06.842] iteration 34000 [583.24 sec]: learning rate : 0.000050 loss : 0.532960 +[08:01:33.570] iteration 34100 [669.97 sec]: learning rate : 0.000050 loss : 0.686563 +[08:03:00.339] iteration 34200 [756.74 sec]: learning rate : 0.000050 loss : 0.799760 +[08:04:27.139] iteration 34300 [843.54 sec]: learning rate : 0.000050 loss : 0.693922 +[08:05:53.855] iteration 34400 [930.26 sec]: learning rate : 0.000050 loss : 0.570900 +[08:07:20.621] iteration 34500 [1017.02 sec]: learning rate : 0.000050 loss : 0.423646 +[08:08:47.312] iteration 34600 [1103.71 sec]: learning rate : 0.000050 loss : 0.812457 +[08:10:14.069] iteration 34700 [1190.47 sec]: learning rate : 0.000050 loss : 0.464563 +[08:11:40.843] iteration 34800 [1277.24 sec]: learning rate : 0.000050 loss : 0.585855 +[08:13:07.565] iteration 34900 [1363.97 sec]: learning rate : 0.000050 loss : 0.343388 +[08:14:34.312] iteration 35000 [1450.71 sec]: learning rate : 0.000050 loss : 1.125863 +[08:16:01.008] iteration 35100 [1537.41 sec]: learning rate : 0.000050 loss : 0.900040 +[08:17:27.691] iteration 35200 [1624.09 sec]: learning rate : 0.000050 loss : 0.475522 +[08:18:54.419] iteration 35300 [1710.82 sec]: learning rate : 0.000050 loss : 0.496932 +[08:20:21.067] iteration 35400 [1797.47 sec]: learning rate : 0.000050 loss : 0.615755 +[08:20:30.577] Epoch 16 Evaluation: +[08:21:20.997] average MSE: 0.05616777017712593 average PSNR: 28.071600058122726 average SSIM: 0.597287583047336 +[08:22:38.442] iteration 35500 [77.38 sec]: learning rate : 0.000050 loss : 0.551961 +[08:24:05.092] iteration 35600 [164.03 sec]: learning rate : 0.000050 loss : 0.798514 +[08:25:31.706] iteration 35700 [250.65 sec]: learning rate : 0.000050 loss : 0.605000 +[08:26:58.400] iteration 35800 [337.34 sec]: learning rate : 0.000050 loss : 0.476941 +[08:28:25.080] iteration 35900 [424.02 sec]: learning rate : 0.000050 loss : 0.641124 +[08:29:51.719] iteration 36000 [510.66 sec]: learning rate : 0.000050 loss : 1.118271 +[08:31:18.414] iteration 36100 [597.35 sec]: learning rate : 0.000050 loss : 1.191278 +[08:32:45.049] iteration 36200 [683.99 sec]: learning rate : 0.000050 loss : 0.706728 +[08:34:11.764] iteration 36300 [770.71 sec]: learning rate : 0.000050 loss : 0.585427 +[08:35:38.474] iteration 36400 [857.42 sec]: learning rate : 0.000050 loss : 0.476287 +[08:37:05.140] iteration 36500 [944.08 sec]: learning rate : 0.000050 loss : 0.737714 +[08:38:31.859] iteration 36600 [1030.80 sec]: learning rate : 0.000050 loss : 0.396321 +[08:39:58.570] iteration 36700 [1117.51 sec]: learning rate : 0.000050 loss : 0.898999 +[08:41:25.251] iteration 36800 [1204.19 sec]: learning rate : 0.000050 loss : 0.395066 +[08:42:51.999] iteration 36900 [1290.94 sec]: learning rate : 0.000050 loss : 0.723672 +[08:44:18.768] iteration 37000 [1377.71 sec]: learning rate : 0.000050 loss : 0.579066 +[08:45:45.480] iteration 37100 [1464.42 sec]: learning rate : 0.000050 loss : 0.701955 +[08:47:12.262] iteration 37200 [1551.20 sec]: learning rate : 0.000050 loss : 0.789603 +[08:48:39.022] iteration 37300 [1637.96 sec]: learning rate : 0.000050 loss : 0.489965 +[08:50:05.749] iteration 37400 [1724.69 sec]: learning rate : 0.000050 loss : 0.670437 +[08:51:27.335] Epoch 17 Evaluation: +[08:52:17.773] average MSE: 0.05619131401181221 average PSNR: 28.077175637360142 average SSIM: 0.5981826828587928 +[08:52:23.237] iteration 37500 [5.40 sec]: learning rate : 0.000050 loss : 0.433613 +[08:53:50.083] 
iteration 37600 [92.25 sec]: learning rate : 0.000050 loss : 0.369885 +[08:55:16.852] iteration 37700 [179.02 sec]: learning rate : 0.000050 loss : 0.547699 +[08:56:43.658] iteration 37800 [265.82 sec]: learning rate : 0.000050 loss : 1.093894 +[08:58:10.431] iteration 37900 [352.59 sec]: learning rate : 0.000050 loss : 1.099108 +[08:59:37.297] iteration 38000 [439.47 sec]: learning rate : 0.000050 loss : 0.868489 +[09:01:04.156] iteration 38100 [526.32 sec]: learning rate : 0.000050 loss : 0.334933 +[09:02:30.937] iteration 38200 [613.10 sec]: learning rate : 0.000050 loss : 0.661565 +[09:03:57.802] iteration 38300 [699.97 sec]: learning rate : 0.000050 loss : 0.351203 +[09:05:24.576] iteration 38400 [786.74 sec]: learning rate : 0.000050 loss : 1.140485 +[09:06:51.412] iteration 38500 [873.58 sec]: learning rate : 0.000050 loss : 0.701969 +[09:08:18.232] iteration 38600 [960.40 sec]: learning rate : 0.000050 loss : 0.676794 +[09:09:44.995] iteration 38700 [1047.16 sec]: learning rate : 0.000050 loss : 0.521871 +[09:11:11.879] iteration 38800 [1134.04 sec]: learning rate : 0.000050 loss : 0.388095 +[09:12:38.700] iteration 38900 [1220.86 sec]: learning rate : 0.000050 loss : 0.644514 +[09:14:05.503] iteration 39000 [1307.67 sec]: learning rate : 0.000050 loss : 0.310803 +[09:15:32.262] iteration 39100 [1394.43 sec]: learning rate : 0.000050 loss : 0.574254 +[09:16:59.010] iteration 39200 [1481.17 sec]: learning rate : 0.000050 loss : 0.674483 +[09:18:25.828] iteration 39300 [1567.99 sec]: learning rate : 0.000050 loss : 0.715404 +[09:19:52.596] iteration 39400 [1654.76 sec]: learning rate : 0.000050 loss : 0.520284 +[09:21:19.311] iteration 39500 [1741.47 sec]: learning rate : 0.000050 loss : 0.537421 +[09:22:26.123] Epoch 18 Evaluation: +[09:23:18.193] average MSE: 0.05590777471661568 average PSNR: 28.12215676941479 average SSIM: 0.5989879458746671 +[09:23:38.376] iteration 39600 [20.12 sec]: learning rate : 0.000050 loss : 0.985303 +[09:25:05.129] iteration 39700 [106.87 sec]: learning rate : 0.000050 loss : 0.743097 +[09:26:31.836] iteration 39800 [193.58 sec]: learning rate : 0.000050 loss : 0.700598 +[09:27:58.602] iteration 39900 [280.35 sec]: learning rate : 0.000050 loss : 0.791192 +[09:29:25.331] iteration 40000 [367.07 sec]: learning rate : 0.000013 loss : 0.650495 +[09:29:25.492] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_40000.pth +[09:30:52.268] iteration 40100 [454.01 sec]: learning rate : 0.000025 loss : 0.641612 +[09:32:19.085] iteration 40200 [540.83 sec]: learning rate : 0.000025 loss : 0.560708 +[09:33:45.800] iteration 40300 [627.54 sec]: learning rate : 0.000025 loss : 0.704052 +[09:35:12.597] iteration 40400 [714.34 sec]: learning rate : 0.000025 loss : 0.608707 +[09:36:39.357] iteration 40500 [801.10 sec]: learning rate : 0.000025 loss : 0.479947 +[09:38:06.091] iteration 40600 [887.84 sec]: learning rate : 0.000025 loss : 0.803444 +[09:39:32.859] iteration 40700 [974.60 sec]: learning rate : 0.000025 loss : 0.625987 +[09:40:59.613] iteration 40800 [1061.36 sec]: learning rate : 0.000025 loss : 0.511845 +[09:42:26.314] iteration 40900 [1148.06 sec]: learning rate : 0.000025 loss : 0.434749 +[09:43:53.119] iteration 41000 [1234.86 sec]: learning rate : 0.000025 loss : 0.462691 +[09:45:19.904] iteration 41100 [1321.65 sec]: learning rate : 0.000025 loss : 0.362868 +[09:46:46.639] iteration 41200 [1408.38 sec]: learning rate : 0.000025 loss : 0.792518 +[09:48:13.394] iteration 41300 [1495.14 sec]: learning rate : 0.000025 loss : 0.373635 +[09:49:40.157] 
iteration 41400 [1581.90 sec]: learning rate : 0.000025 loss : 0.655068 +[09:51:06.854] iteration 41500 [1668.60 sec]: learning rate : 0.000025 loss : 0.731129 +[09:52:33.619] iteration 41600 [1755.36 sec]: learning rate : 0.000025 loss : 0.503482 +[09:53:25.620] Epoch 19 Evaluation: +[09:54:18.004] average MSE: 0.055765923112630844 average PSNR: 28.13698908653535 average SSIM: 0.5993716128758847 +[09:54:53.039] iteration 41700 [34.97 sec]: learning rate : 0.000025 loss : 0.493538 +[09:56:19.728] iteration 41800 [121.66 sec]: learning rate : 0.000025 loss : 0.405636 +[09:57:46.477] iteration 41900 [208.41 sec]: learning rate : 0.000025 loss : 0.489686 +[09:59:13.264] iteration 42000 [295.20 sec]: learning rate : 0.000025 loss : 0.903025 +[10:00:39.987] iteration 42100 [381.92 sec]: learning rate : 0.000025 loss : 0.516038 +[10:02:06.734] iteration 42200 [468.67 sec]: learning rate : 0.000025 loss : 0.630207 +[10:03:33.529] iteration 42300 [555.46 sec]: learning rate : 0.000025 loss : 0.383436 +[10:05:00.276] iteration 42400 [642.21 sec]: learning rate : 0.000025 loss : 0.556182 +[10:06:27.070] iteration 42500 [729.00 sec]: learning rate : 0.000025 loss : 0.603861 +[10:07:53.873] iteration 42600 [815.81 sec]: learning rate : 0.000025 loss : 0.493579 +[10:09:20.615] iteration 42700 [902.55 sec]: learning rate : 0.000025 loss : 0.685034 +[10:10:47.444] iteration 42800 [989.38 sec]: learning rate : 0.000025 loss : 0.720820 +[10:12:14.259] iteration 42900 [1076.25 sec]: learning rate : 0.000025 loss : 0.559108 +[10:13:40.983] iteration 43000 [1162.92 sec]: learning rate : 0.000025 loss : 1.012941 +[10:15:07.763] iteration 43100 [1249.70 sec]: learning rate : 0.000025 loss : 0.432209 +[10:16:34.518] iteration 43200 [1336.45 sec]: learning rate : 0.000025 loss : 0.709668 +[10:18:01.323] iteration 43300 [1423.26 sec]: learning rate : 0.000025 loss : 0.803380 +[10:19:28.155] iteration 43400 [1510.09 sec]: learning rate : 0.000025 loss : 0.649528 +[10:20:54.947] iteration 43500 [1596.88 sec]: learning rate : 0.000025 loss : 0.548576 +[10:22:21.683] iteration 43600 [1683.62 sec]: learning rate : 0.000025 loss : 0.532071 +[10:23:48.491] iteration 43700 [1770.42 sec]: learning rate : 0.000025 loss : 0.702239 +[10:24:25.745] Epoch 20 Evaluation: +[10:25:16.148] average MSE: 0.05553659796714783 average PSNR: 28.151431452976606 average SSIM: 0.5998129017063919 +[10:26:05.800] iteration 43800 [49.59 sec]: learning rate : 0.000025 loss : 0.490497 +[10:27:32.614] iteration 43900 [136.40 sec]: learning rate : 0.000025 loss : 0.551505 +[10:28:59.414] iteration 44000 [223.20 sec]: learning rate : 0.000025 loss : 0.596241 +[10:30:26.137] iteration 44100 [309.93 sec]: learning rate : 0.000025 loss : 0.724240 +[10:31:52.918] iteration 44200 [396.71 sec]: learning rate : 0.000025 loss : 1.006566 +[10:33:19.673] iteration 44300 [483.46 sec]: learning rate : 0.000025 loss : 0.910148 +[10:34:46.427] iteration 44400 [570.22 sec]: learning rate : 0.000025 loss : 0.539202 +[10:36:13.238] iteration 44500 [657.03 sec]: learning rate : 0.000025 loss : 0.642790 +[10:37:39.974] iteration 44600 [743.76 sec]: learning rate : 0.000025 loss : 1.284309 +[10:39:06.763] iteration 44700 [830.55 sec]: learning rate : 0.000025 loss : 0.702786 +[10:40:33.567] iteration 44800 [917.36 sec]: learning rate : 0.000025 loss : 1.185821 +[10:42:00.292] iteration 44900 [1004.08 sec]: learning rate : 0.000025 loss : 1.328419 +[10:43:27.070] iteration 45000 [1090.86 sec]: learning rate : 0.000025 loss : 0.598202 +[10:44:53.811] iteration 45100 
[1177.60 sec]: learning rate : 0.000025 loss : 0.542721 +[10:46:20.593] iteration 45200 [1264.38 sec]: learning rate : 0.000025 loss : 0.379263 +[10:47:47.412] iteration 45300 [1351.20 sec]: learning rate : 0.000025 loss : 0.534928 +[10:49:14.134] iteration 45400 [1437.92 sec]: learning rate : 0.000025 loss : 0.785568 +[10:50:40.928] iteration 45500 [1524.72 sec]: learning rate : 0.000025 loss : 0.567394 +[10:52:07.707] iteration 45600 [1611.50 sec]: learning rate : 0.000025 loss : 0.678883 +[10:53:34.429] iteration 45700 [1698.22 sec]: learning rate : 0.000025 loss : 0.628464 +[10:55:01.201] iteration 45800 [1784.99 sec]: learning rate : 0.000025 loss : 0.569850 +[10:55:23.727] Epoch 21 Evaluation: +[10:56:15.813] average MSE: 0.055470824241638184 average PSNR: 28.162714961473473 average SSIM: 0.6000344071892255 +[10:57:20.201] iteration 45900 [64.32 sec]: learning rate : 0.000025 loss : 0.264647 +[10:58:47.028] iteration 46000 [151.15 sec]: learning rate : 0.000025 loss : 0.341947 +[11:00:13.827] iteration 46100 [237.95 sec]: learning rate : 0.000025 loss : 0.626456 +[11:01:40.562] iteration 46200 [324.69 sec]: learning rate : 0.000025 loss : 0.601735 +[11:03:07.369] iteration 46300 [411.49 sec]: learning rate : 0.000025 loss : 0.882218 +[11:04:34.182] iteration 46400 [498.31 sec]: learning rate : 0.000025 loss : 0.625942 +[11:06:00.934] iteration 46500 [585.06 sec]: learning rate : 0.000025 loss : 0.880569 +[11:07:27.695] iteration 46600 [671.82 sec]: learning rate : 0.000025 loss : 0.904312 +[11:08:54.470] iteration 46700 [758.59 sec]: learning rate : 0.000025 loss : 0.290488 +[11:10:21.204] iteration 46800 [845.33 sec]: learning rate : 0.000025 loss : 0.791323 +[11:11:48.006] iteration 46900 [932.13 sec]: learning rate : 0.000025 loss : 0.558582 +[11:13:14.763] iteration 47000 [1018.89 sec]: learning rate : 0.000025 loss : 0.327169 +[11:14:41.556] iteration 47100 [1105.68 sec]: learning rate : 0.000025 loss : 0.780772 +[11:16:08.378] iteration 47200 [1192.50 sec]: learning rate : 0.000025 loss : 0.515775 +[11:17:35.123] iteration 47300 [1279.25 sec]: learning rate : 0.000025 loss : 0.373249 +[11:19:01.929] iteration 47400 [1366.05 sec]: learning rate : 0.000025 loss : 0.402174 +[11:20:28.766] iteration 47500 [1452.89 sec]: learning rate : 0.000025 loss : 0.477670 +[11:21:55.511] iteration 47600 [1539.64 sec]: learning rate : 0.000025 loss : 0.551714 +[11:23:22.369] iteration 47700 [1626.49 sec]: learning rate : 0.000025 loss : 0.403655 +[11:24:49.171] iteration 47800 [1713.30 sec]: learning rate : 0.000025 loss : 0.307443 +[11:26:15.927] iteration 47900 [1800.05 sec]: learning rate : 0.000025 loss : 0.464305 +[11:26:23.713] Epoch 22 Evaluation: +[11:27:16.278] average MSE: 0.0553852915763855 average PSNR: 28.173225073100863 average SSIM: 0.6010319817095279 +[11:28:35.556] iteration 48000 [79.22 sec]: learning rate : 0.000025 loss : 0.642741 +[11:30:02.341] iteration 48100 [166.00 sec]: learning rate : 0.000025 loss : 0.855257 +[11:31:29.089] iteration 48200 [252.75 sec]: learning rate : 0.000025 loss : 0.688718 +[11:32:55.892] iteration 48300 [339.55 sec]: learning rate : 0.000025 loss : 0.521772 +[11:34:22.719] iteration 48400 [426.38 sec]: learning rate : 0.000025 loss : 0.549633 +[11:35:49.458] iteration 48500 [513.12 sec]: learning rate : 0.000025 loss : 0.781367 +[11:37:16.243] iteration 48600 [599.90 sec]: learning rate : 0.000025 loss : 0.562918 +[11:38:42.991] iteration 48700 [686.65 sec]: learning rate : 0.000025 loss : 0.518579 +[11:40:09.770] iteration 48800 [773.43 sec]: 
learning rate : 0.000025 loss : 0.550798 +[11:41:36.568] iteration 48900 [860.23 sec]: learning rate : 0.000025 loss : 0.457363 +[11:43:03.357] iteration 49000 [947.02 sec]: learning rate : 0.000025 loss : 0.510794 +[11:44:30.146] iteration 49100 [1033.80 sec]: learning rate : 0.000025 loss : 0.643216 +[11:45:56.946] iteration 49200 [1120.61 sec]: learning rate : 0.000025 loss : 0.629388 +[11:47:23.717] iteration 49300 [1207.38 sec]: learning rate : 0.000025 loss : 0.409376 +[11:48:50.549] iteration 49400 [1294.21 sec]: learning rate : 0.000025 loss : 0.586676 +[11:50:17.377] iteration 49500 [1381.04 sec]: learning rate : 0.000025 loss : 0.848281 +[11:51:44.139] iteration 49600 [1467.80 sec]: learning rate : 0.000025 loss : 0.720532 +[11:53:10.959] iteration 49700 [1554.62 sec]: learning rate : 0.000025 loss : 0.832828 +[11:54:37.786] iteration 49800 [1641.45 sec]: learning rate : 0.000025 loss : 0.334596 +[11:56:04.572] iteration 49900 [1728.23 sec]: learning rate : 0.000025 loss : 0.544026 +[11:57:24.438] Epoch 23 Evaluation: +[11:58:14.739] average MSE: 0.055305372923612595 average PSNR: 28.181915368834655 average SSIM: 0.6010085561147456 +[11:58:21.929] iteration 50000 [7.13 sec]: learning rate : 0.000025 loss : 0.495652 +[11:59:48.792] iteration 50100 [93.99 sec]: learning rate : 0.000025 loss : 0.402836 +[12:01:15.543] iteration 50200 [180.74 sec]: learning rate : 0.000025 loss : 0.392229 +[12:02:42.343] iteration 50300 [267.54 sec]: learning rate : 0.000025 loss : 0.642423 +[12:04:09.168] iteration 50400 [354.37 sec]: learning rate : 0.000025 loss : 0.795557 +[12:05:35.919] iteration 50500 [441.13 sec]: learning rate : 0.000025 loss : 0.747687 +[12:07:02.760] iteration 50600 [527.96 sec]: learning rate : 0.000025 loss : 0.405439 +[12:08:29.542] iteration 50700 [614.74 sec]: learning rate : 0.000025 loss : 0.915773 +[12:09:56.392] iteration 50800 [701.59 sec]: learning rate : 0.000025 loss : 0.523112 +[12:11:23.249] iteration 50900 [788.45 sec]: learning rate : 0.000025 loss : 0.594934 +[12:12:50.037] iteration 51000 [875.24 sec]: learning rate : 0.000025 loss : 0.824571 +[12:14:16.928] iteration 51100 [962.13 sec]: learning rate : 0.000025 loss : 0.523579 +[12:15:43.773] iteration 51200 [1048.97 sec]: learning rate : 0.000025 loss : 0.654867 +[12:17:10.564] iteration 51300 [1135.76 sec]: learning rate : 0.000025 loss : 0.286649 +[12:18:37.428] iteration 51400 [1222.63 sec]: learning rate : 0.000025 loss : 0.534003 +[12:20:04.307] iteration 51500 [1309.51 sec]: learning rate : 0.000025 loss : 0.442144 +[12:21:31.122] iteration 51600 [1396.32 sec]: learning rate : 0.000025 loss : 0.430050 +[12:22:58.056] iteration 51700 [1483.25 sec]: learning rate : 0.000025 loss : 0.566738 +[12:24:24.941] iteration 51800 [1570.14 sec]: learning rate : 0.000025 loss : 0.582287 +[12:25:51.758] iteration 51900 [1656.96 sec]: learning rate : 0.000025 loss : 0.741240 +[12:27:18.650] iteration 52000 [1743.85 sec]: learning rate : 0.000025 loss : 0.319677 +[12:28:23.753] Epoch 24 Evaluation: +[12:29:13.867] average MSE: 0.05557478591799736 average PSNR: 28.16300561764633 average SSIM: 0.6007983512930862 +[12:29:35.812] iteration 52100 [21.88 sec]: learning rate : 0.000025 loss : 0.923510 +[12:31:02.710] iteration 52200 [108.78 sec]: learning rate : 0.000025 loss : 0.286816 +[12:32:29.576] iteration 52300 [195.65 sec]: learning rate : 0.000025 loss : 0.408619 +[12:33:56.407] iteration 52400 [282.48 sec]: learning rate : 0.000025 loss : 0.718120 +[12:35:23.292] iteration 52500 [369.36 sec]: learning rate : 
0.000025 loss : 0.649106 +[12:36:50.166] iteration 52600 [456.24 sec]: learning rate : 0.000025 loss : 0.729588 +[12:38:16.998] iteration 52700 [543.07 sec]: learning rate : 0.000025 loss : 0.731861 +[12:39:43.913] iteration 52800 [629.98 sec]: learning rate : 0.000025 loss : 0.579488 +[12:41:10.798] iteration 52900 [716.87 sec]: learning rate : 0.000025 loss : 0.648896 +[12:42:37.638] iteration 53000 [803.71 sec]: learning rate : 0.000025 loss : 0.751492 +[12:44:04.573] iteration 53100 [890.64 sec]: learning rate : 0.000025 loss : 0.596804 +[12:45:31.478] iteration 53200 [977.55 sec]: learning rate : 0.000025 loss : 0.498356 +[12:46:58.322] iteration 53300 [1064.39 sec]: learning rate : 0.000025 loss : 0.772012 +[12:48:25.230] iteration 53400 [1151.30 sec]: learning rate : 0.000025 loss : 0.776900 +[12:49:52.140] iteration 53500 [1238.21 sec]: learning rate : 0.000025 loss : 0.514662 +[12:51:18.997] iteration 53600 [1325.07 sec]: learning rate : 0.000025 loss : 0.500727 +[12:52:45.921] iteration 53700 [1411.99 sec]: learning rate : 0.000025 loss : 0.568657 +[12:54:12.834] iteration 53800 [1498.90 sec]: learning rate : 0.000025 loss : 0.384159 +[12:55:39.691] iteration 53900 [1585.76 sec]: learning rate : 0.000025 loss : 0.726393 +[12:57:06.575] iteration 54000 [1672.65 sec]: learning rate : 0.000025 loss : 0.699668 +[12:58:33.423] iteration 54100 [1759.49 sec]: learning rate : 0.000025 loss : 0.609216 +[12:59:23.822] Epoch 25 Evaluation: +[13:00:14.172] average MSE: 0.055239781737327576 average PSNR: 28.192217496736266 average SSIM: 0.6001933424882355 +[13:00:50.880] iteration 54200 [36.64 sec]: learning rate : 0.000025 loss : 0.931894 +[13:02:17.818] iteration 54300 [123.58 sec]: learning rate : 0.000025 loss : 0.578118 +[13:03:44.667] iteration 54400 [210.43 sec]: learning rate : 0.000025 loss : 0.529482 +[13:05:11.569] iteration 54500 [297.33 sec]: learning rate : 0.000025 loss : 0.658839 +[13:06:38.446] iteration 54600 [384.21 sec]: learning rate : 0.000025 loss : 0.455375 +[13:08:05.298] iteration 54700 [471.06 sec]: learning rate : 0.000025 loss : 0.458717 +[13:09:32.232] iteration 54800 [558.00 sec]: learning rate : 0.000025 loss : 0.619850 +[13:10:59.162] iteration 54900 [644.93 sec]: learning rate : 0.000025 loss : 0.703962 +[13:12:26.018] iteration 55000 [731.79 sec]: learning rate : 0.000025 loss : 0.669866 +[13:13:52.939] iteration 55100 [818.70 sec]: learning rate : 0.000025 loss : 0.489534 +[13:15:19.782] iteration 55200 [905.55 sec]: learning rate : 0.000025 loss : 0.655757 +[13:16:46.666] iteration 55300 [992.43 sec]: learning rate : 0.000025 loss : 0.632504 +[13:18:13.588] iteration 55400 [1079.35 sec]: learning rate : 0.000025 loss : 0.510526 +[13:19:40.437] iteration 55500 [1166.20 sec]: learning rate : 0.000025 loss : 0.695143 +[13:21:07.317] iteration 55600 [1253.08 sec]: learning rate : 0.000025 loss : 0.631150 +[13:22:34.220] iteration 55700 [1339.98 sec]: learning rate : 0.000025 loss : 0.546966 +[13:24:01.057] iteration 55800 [1426.82 sec]: learning rate : 0.000025 loss : 0.453046 +[13:25:27.956] iteration 55900 [1513.72 sec]: learning rate : 0.000025 loss : 0.383575 +[13:26:54.802] iteration 56000 [1600.57 sec]: learning rate : 0.000025 loss : 0.545284 +[13:28:21.704] iteration 56100 [1687.47 sec]: learning rate : 0.000025 loss : 0.646227 +[13:29:48.579] iteration 56200 [1774.34 sec]: learning rate : 0.000025 loss : 0.755428 +[13:30:24.164] Epoch 26 Evaluation: +[13:31:14.025] average MSE: 0.05529506877064705 average PSNR: 28.183829126118212 average SSIM: 
0.6004286331784331 +[13:32:05.482] iteration 56300 [51.39 sec]: learning rate : 0.000025 loss : 0.505801 +[13:33:32.390] iteration 56400 [138.30 sec]: learning rate : 0.000025 loss : 0.611924 +[13:34:59.177] iteration 56500 [225.09 sec]: learning rate : 0.000025 loss : 0.428142 +[13:36:26.033] iteration 56600 [311.95 sec]: learning rate : 0.000025 loss : 0.795283 +[13:37:52.890] iteration 56700 [398.80 sec]: learning rate : 0.000025 loss : 1.169131 +[13:39:19.668] iteration 56800 [485.58 sec]: learning rate : 0.000025 loss : 0.746190 +[13:40:46.529] iteration 56900 [572.46 sec]: learning rate : 0.000025 loss : 0.346044 +[13:42:13.362] iteration 57000 [659.27 sec]: learning rate : 0.000025 loss : 0.709040 +[13:43:40.153] iteration 57100 [746.07 sec]: learning rate : 0.000025 loss : 0.661705 +[13:45:06.980] iteration 57200 [832.89 sec]: learning rate : 0.000025 loss : 0.785549 +[13:46:33.792] iteration 57300 [919.70 sec]: learning rate : 0.000025 loss : 0.733119 +[13:48:00.538] iteration 57400 [1006.45 sec]: learning rate : 0.000025 loss : 0.621883 +[13:49:27.359] iteration 57500 [1093.27 sec]: learning rate : 0.000025 loss : 0.768705 +[13:50:54.176] iteration 57600 [1180.09 sec]: learning rate : 0.000025 loss : 0.562918 +[13:52:20.942] iteration 57700 [1266.85 sec]: learning rate : 0.000025 loss : 0.536723 +[13:53:47.766] iteration 57800 [1353.68 sec]: learning rate : 0.000025 loss : 0.499881 +[13:55:14.574] iteration 57900 [1440.49 sec]: learning rate : 0.000025 loss : 0.583919 +[13:56:41.336] iteration 58000 [1527.25 sec]: learning rate : 0.000025 loss : 0.542974 +[13:58:08.172] iteration 58100 [1614.08 sec]: learning rate : 0.000025 loss : 0.509875 +[13:59:34.945] iteration 58200 [1700.86 sec]: learning rate : 0.000025 loss : 0.697573 +[14:01:01.779] iteration 58300 [1787.69 sec]: learning rate : 0.000025 loss : 0.466376 +[14:01:22.566] Epoch 27 Evaluation: +[14:02:13.963] average MSE: 0.055048245936632156 average PSNR: 28.208108188139786 average SSIM: 0.601719955158207 +[14:03:20.261] iteration 58400 [66.24 sec]: learning rate : 0.000025 loss : 0.391233 +[14:04:47.058] iteration 58500 [153.03 sec]: learning rate : 0.000025 loss : 0.232908 +[14:06:13.885] iteration 58600 [239.86 sec]: learning rate : 0.000025 loss : 0.553849 +[14:07:40.705] iteration 58700 [326.68 sec]: learning rate : 0.000025 loss : 0.839748 +[14:09:07.467] iteration 58800 [413.44 sec]: learning rate : 0.000025 loss : 0.580039 +[14:10:34.259] iteration 58900 [500.23 sec]: learning rate : 0.000025 loss : 0.665243 +[14:12:01.095] iteration 59000 [587.07 sec]: learning rate : 0.000025 loss : 0.313107 +[14:13:27.891] iteration 59100 [673.86 sec]: learning rate : 0.000025 loss : 0.897596 +[14:14:54.743] iteration 59200 [760.72 sec]: learning rate : 0.000025 loss : 0.449837 +[14:16:21.596] iteration 59300 [847.57 sec]: learning rate : 0.000025 loss : 0.641069 +[14:17:48.391] iteration 59400 [934.36 sec]: learning rate : 0.000025 loss : 0.999529 +[14:19:15.240] iteration 59500 [1021.22 sec]: learning rate : 0.000025 loss : 0.496081 +[14:20:42.032] iteration 59600 [1108.01 sec]: learning rate : 0.000025 loss : 0.927622 +[14:22:08.856] iteration 59700 [1194.83 sec]: learning rate : 0.000025 loss : 0.642696 +[14:23:35.720] iteration 59800 [1281.69 sec]: learning rate : 0.000025 loss : 0.440472 +[14:25:02.510] iteration 59900 [1368.48 sec]: learning rate : 0.000025 loss : 0.633221 +[14:26:29.377] iteration 60000 [1455.35 sec]: learning rate : 0.000006 loss : 0.426288 +[14:26:29.544] save model to 
model/FSMNet_fastmri_8x_t5_kspace_time/iter_60000.pth +[14:27:56.374] iteration 60100 [1542.35 sec]: learning rate : 0.000013 loss : 0.357040 +[14:29:23.171] iteration 60200 [1629.15 sec]: learning rate : 0.000013 loss : 0.481221 +[14:30:50.042] iteration 60300 [1716.02 sec]: learning rate : 0.000013 loss : 0.542243 +[14:32:16.941] iteration 60400 [1802.92 sec]: learning rate : 0.000013 loss : 0.763359 +[14:32:23.003] Epoch 28 Evaluation: +[14:33:13.052] average MSE: 0.05513148382306099 average PSNR: 28.19997840951861 average SSIM: 0.6016353317273379 +[14:34:33.999] iteration 60500 [80.88 sec]: learning rate : 0.000013 loss : 0.988284 +[14:36:00.892] iteration 60600 [167.78 sec]: learning rate : 0.000013 loss : 0.552108 +[14:37:27.720] iteration 60700 [254.60 sec]: learning rate : 0.000013 loss : 0.395630 +[14:38:54.493] iteration 60800 [341.38 sec]: learning rate : 0.000013 loss : 0.525593 +[14:40:21.329] iteration 60900 [428.21 sec]: learning rate : 0.000013 loss : 0.786295 +[14:41:48.194] iteration 61000 [515.08 sec]: learning rate : 0.000013 loss : 0.526301 +[14:43:14.995] iteration 61100 [601.88 sec]: learning rate : 0.000013 loss : 0.274889 +[14:44:41.836] iteration 61200 [688.72 sec]: learning rate : 0.000013 loss : 0.498063 +[14:46:08.680] iteration 61300 [775.56 sec]: learning rate : 0.000013 loss : 0.439417 +[14:47:35.453] iteration 61400 [862.34 sec]: learning rate : 0.000013 loss : 0.552196 +[14:49:02.256] iteration 61500 [949.14 sec]: learning rate : 0.000013 loss : 0.700743 +[14:50:29.045] iteration 61600 [1035.93 sec]: learning rate : 0.000013 loss : 0.775923 +[14:51:55.877] iteration 61700 [1122.76 sec]: learning rate : 0.000013 loss : 0.586316 +[14:53:22.673] iteration 61800 [1209.56 sec]: learning rate : 0.000013 loss : 0.397631 +[14:54:49.467] iteration 61900 [1296.35 sec]: learning rate : 0.000013 loss : 0.608663 +[14:56:16.327] iteration 62000 [1383.21 sec]: learning rate : 0.000013 loss : 0.438841 +[14:57:43.136] iteration 62100 [1470.02 sec]: learning rate : 0.000013 loss : 0.528166 +[14:59:09.961] iteration 62200 [1556.85 sec]: learning rate : 0.000013 loss : 0.409516 +[15:00:36.784] iteration 62300 [1643.67 sec]: learning rate : 0.000013 loss : 0.428180 +[15:02:03.561] iteration 62400 [1730.45 sec]: learning rate : 0.000013 loss : 0.852455 +[15:03:21.693] Epoch 29 Evaluation: +[15:04:11.990] average MSE: 0.05489760637283325 average PSNR: 28.22851315128196 average SSIM: 0.6022388207263344 +[15:04:20.916] iteration 62500 [8.86 sec]: learning rate : 0.000013 loss : 0.551154 +[15:05:47.805] iteration 62600 [95.75 sec]: learning rate : 0.000013 loss : 0.947370 +[15:07:14.582] iteration 62700 [182.53 sec]: learning rate : 0.000013 loss : 0.683737 +[15:08:41.473] iteration 62800 [269.42 sec]: learning rate : 0.000013 loss : 0.478317 +[15:10:08.269] iteration 62900 [356.22 sec]: learning rate : 0.000013 loss : 0.572866 +[15:11:35.151] iteration 63000 [443.10 sec]: learning rate : 0.000013 loss : 0.525382 +[15:13:02.028] iteration 63100 [529.97 sec]: learning rate : 0.000013 loss : 0.584224 +[15:14:28.853] iteration 63200 [616.80 sec]: learning rate : 0.000013 loss : 0.692159 +[15:15:55.703] iteration 63300 [703.65 sec]: learning rate : 0.000013 loss : 0.566426 +[15:17:22.561] iteration 63400 [790.51 sec]: learning rate : 0.000013 loss : 0.533737 +[15:18:49.370] iteration 63500 [877.32 sec]: learning rate : 0.000013 loss : 0.570705 +[15:20:16.240] iteration 63600 [964.19 sec]: learning rate : 0.000013 loss : 0.692240 +[15:21:43.034] iteration 63700 [1050.98 sec]: learning 
rate : 0.000013 loss : 0.587613 +[15:23:09.877] iteration 63800 [1137.82 sec]: learning rate : 0.000013 loss : 0.624004 +[15:24:36.726] iteration 63900 [1224.67 sec]: learning rate : 0.000013 loss : 0.557426 +[15:26:03.593] iteration 64000 [1311.54 sec]: learning rate : 0.000013 loss : 0.701765 +[15:27:30.485] iteration 64100 [1398.43 sec]: learning rate : 0.000013 loss : 0.791001 +[15:28:57.369] iteration 64200 [1485.32 sec]: learning rate : 0.000013 loss : 1.121631 +[15:30:24.197] iteration 64300 [1572.14 sec]: learning rate : 0.000013 loss : 0.874008 +[15:31:51.080] iteration 64400 [1659.03 sec]: learning rate : 0.000013 loss : 0.822773 +[15:33:17.961] iteration 64500 [1745.91 sec]: learning rate : 0.000013 loss : 0.676041 +[15:34:21.294] Epoch 30 Evaluation: +[15:35:11.676] average MSE: 0.054817233234643936 average PSNR: 28.233001966697973 average SSIM: 0.6026884514930516 +[15:35:35.345] iteration 64600 [23.61 sec]: learning rate : 0.000013 loss : 0.641428 +[15:37:02.266] iteration 64700 [110.53 sec]: learning rate : 0.000013 loss : 0.516631 +[15:38:29.111] iteration 64800 [197.37 sec]: learning rate : 0.000013 loss : 0.593694 +[15:39:55.933] iteration 64900 [284.20 sec]: learning rate : 0.000013 loss : 0.743533 +[15:41:22.825] iteration 65000 [371.09 sec]: learning rate : 0.000013 loss : 0.605634 +[15:42:49.678] iteration 65100 [457.94 sec]: learning rate : 0.000013 loss : 0.522957 +[15:44:16.481] iteration 65200 [544.74 sec]: learning rate : 0.000013 loss : 0.510189 +[15:45:43.307] iteration 65300 [631.57 sec]: learning rate : 0.000013 loss : 0.544453 +[15:47:10.195] iteration 65400 [718.46 sec]: learning rate : 0.000013 loss : 0.490976 +[15:48:37.007] iteration 65500 [805.27 sec]: learning rate : 0.000013 loss : 0.631909 +[15:50:03.876] iteration 65600 [892.14 sec]: learning rate : 0.000013 loss : 0.723711 +[15:51:30.731] iteration 65700 [978.99 sec]: learning rate : 0.000013 loss : 0.542940 +[15:52:57.525] iteration 65800 [1065.79 sec]: learning rate : 0.000013 loss : 0.500807 +[15:54:24.365] iteration 65900 [1152.63 sec]: learning rate : 0.000013 loss : 0.926657 +[15:55:51.237] iteration 66000 [1239.50 sec]: learning rate : 0.000013 loss : 0.625619 +[15:57:18.066] iteration 66100 [1326.33 sec]: learning rate : 0.000013 loss : 0.589913 +[15:58:44.901] iteration 66200 [1413.16 sec]: learning rate : 0.000013 loss : 0.523205 +[16:00:11.772] iteration 66300 [1500.03 sec]: learning rate : 0.000013 loss : 0.660805 +[16:01:38.580] iteration 66400 [1586.84 sec]: learning rate : 0.000013 loss : 0.664917 +[16:03:05.452] iteration 66500 [1673.71 sec]: learning rate : 0.000013 loss : 0.485555 +[16:04:32.328] iteration 66600 [1760.59 sec]: learning rate : 0.000013 loss : 0.560611 +[16:05:20.928] Epoch 31 Evaluation: +[16:06:10.946] average MSE: 0.05482204258441925 average PSNR: 28.229282488551718 average SSIM: 0.6023792541990602 +[16:06:49.388] iteration 66700 [38.38 sec]: learning rate : 0.000013 loss : 1.009463 +[16:08:16.297] iteration 66800 [125.29 sec]: learning rate : 0.000013 loss : 0.671146 +[16:09:43.112] iteration 66900 [212.10 sec]: learning rate : 0.000013 loss : 0.381783 +[16:11:09.993] iteration 67000 [298.98 sec]: learning rate : 0.000013 loss : 0.579893 +[16:12:36.940] iteration 67100 [385.95 sec]: learning rate : 0.000013 loss : 0.420612 +[16:14:03.787] iteration 67200 [472.78 sec]: learning rate : 0.000013 loss : 0.648703 +[16:15:30.678] iteration 67300 [559.67 sec]: learning rate : 0.000013 loss : 0.704861 +[16:16:57.539] iteration 67400 [646.53 sec]: learning rate : 0.000013 
loss : 0.667168 +[16:18:24.491] iteration 67500 [733.48 sec]: learning rate : 0.000013 loss : 0.946706 +[16:19:51.417] iteration 67600 [820.41 sec]: learning rate : 0.000013 loss : 0.588822 +[16:21:18.272] iteration 67700 [907.26 sec]: learning rate : 0.000013 loss : 0.680351 +[16:22:45.201] iteration 67800 [994.19 sec]: learning rate : 0.000013 loss : 0.601477 +[16:24:12.138] iteration 67900 [1081.13 sec]: learning rate : 0.000013 loss : 0.442984 +[16:25:39.014] iteration 68000 [1168.00 sec]: learning rate : 0.000013 loss : 0.644414 +[16:27:05.977] iteration 68100 [1254.97 sec]: learning rate : 0.000013 loss : 0.638621 +[16:28:32.878] iteration 68200 [1341.87 sec]: learning rate : 0.000013 loss : 0.531606 +[16:29:59.818] iteration 68300 [1428.81 sec]: learning rate : 0.000013 loss : 0.292360 +[16:31:26.722] iteration 68400 [1515.71 sec]: learning rate : 0.000013 loss : 0.198762 +[16:32:53.599] iteration 68500 [1602.59 sec]: learning rate : 0.000013 loss : 0.592542 +[16:34:20.551] iteration 68600 [1689.54 sec]: learning rate : 0.000013 loss : 0.652091 +[16:35:47.542] iteration 68700 [1776.55 sec]: learning rate : 0.000013 loss : 0.665225 +[16:36:21.397] Epoch 32 Evaluation: +[16:37:11.712] average MSE: 0.05482378602027893 average PSNR: 28.238472660532153 average SSIM: 0.6029254681998458 +[16:38:04.914] iteration 68800 [53.14 sec]: learning rate : 0.000013 loss : 0.536492 +[16:39:31.900] iteration 68900 [140.12 sec]: learning rate : 0.000013 loss : 0.705845 +[16:40:58.811] iteration 69000 [227.10 sec]: learning rate : 0.000013 loss : 0.769871 +[16:42:25.706] iteration 69100 [313.93 sec]: learning rate : 0.000013 loss : 0.419459 +[16:43:52.627] iteration 69200 [400.85 sec]: learning rate : 0.000013 loss : 0.806476 +[16:45:19.568] iteration 69300 [487.79 sec]: learning rate : 0.000013 loss : 0.407657 +[16:46:46.438] iteration 69400 [574.66 sec]: learning rate : 0.000013 loss : 0.406011 +[16:48:13.386] iteration 69500 [661.61 sec]: learning rate : 0.000013 loss : 0.416050 +[16:49:40.314] iteration 69600 [748.54 sec]: learning rate : 0.000013 loss : 0.732661 +[16:51:07.170] iteration 69700 [835.39 sec]: learning rate : 0.000013 loss : 0.664179 +[16:52:34.084] iteration 69800 [922.31 sec]: learning rate : 0.000013 loss : 0.784161 +[16:54:00.931] iteration 69900 [1009.16 sec]: learning rate : 0.000013 loss : 0.660076 +[16:55:27.777] iteration 70000 [1096.00 sec]: learning rate : 0.000013 loss : 0.814639 +[16:56:54.688] iteration 70100 [1182.91 sec]: learning rate : 0.000013 loss : 0.665988 +[16:58:21.541] iteration 70200 [1269.77 sec]: learning rate : 0.000013 loss : 0.393253 +[16:59:48.432] iteration 70300 [1356.66 sec]: learning rate : 0.000013 loss : 0.564947 +[17:01:15.301] iteration 70400 [1443.53 sec]: learning rate : 0.000013 loss : 0.913827 +[17:02:42.228] iteration 70500 [1530.45 sec]: learning rate : 0.000013 loss : 0.326478 +[17:04:09.117] iteration 70600 [1617.34 sec]: learning rate : 0.000013 loss : 0.332409 +[17:05:35.983] iteration 70700 [1704.21 sec]: learning rate : 0.000013 loss : 0.510290 +[17:07:02.900] iteration 70800 [1791.12 sec]: learning rate : 0.000013 loss : 0.567169 +[17:07:21.971] Epoch 33 Evaluation: +[17:08:12.826] average MSE: 0.054686080664396286 average PSNR: 28.248551840094333 average SSIM: 0.6021107342881493 +[17:09:20.786] iteration 70900 [67.90 sec]: learning rate : 0.000013 loss : 0.583598 +[17:10:47.750] iteration 71000 [154.86 sec]: learning rate : 0.000013 loss : 1.010447 +[17:12:14.647] iteration 71100 [241.76 sec]: learning rate : 0.000013 loss : 
0.664574 +[17:13:41.517] iteration 71200 [328.63 sec]: learning rate : 0.000013 loss : 0.637764 +[17:15:08.431] iteration 71300 [415.54 sec]: learning rate : 0.000013 loss : 0.688340 +[17:16:35.346] iteration 71400 [502.46 sec]: learning rate : 0.000013 loss : 0.443358 +[17:18:02.211] iteration 71500 [589.32 sec]: learning rate : 0.000013 loss : 0.752675 +[17:19:29.195] iteration 71600 [676.31 sec]: learning rate : 0.000013 loss : 0.644495 +[17:20:56.064] iteration 71700 [763.17 sec]: learning rate : 0.000013 loss : 0.511932 +[17:22:23.012] iteration 71800 [850.12 sec]: learning rate : 0.000013 loss : 0.879163 +[17:23:49.956] iteration 71900 [937.07 sec]: learning rate : 0.000013 loss : 0.585545 +[17:25:16.830] iteration 72000 [1023.94 sec]: learning rate : 0.000013 loss : 0.606407 +[17:26:43.780] iteration 72100 [1110.89 sec]: learning rate : 0.000013 loss : 0.408606 +[17:28:10.710] iteration 72200 [1197.82 sec]: learning rate : 0.000013 loss : 0.503974 +[17:29:37.584] iteration 72300 [1284.70 sec]: learning rate : 0.000013 loss : 0.549124 +[17:31:04.505] iteration 72400 [1371.62 sec]: learning rate : 0.000013 loss : 0.639035 +[17:32:31.468] iteration 72500 [1458.58 sec]: learning rate : 0.000013 loss : 0.810144 +[17:33:58.352] iteration 72600 [1545.46 sec]: learning rate : 0.000013 loss : 0.626607 +[17:35:25.276] iteration 72700 [1632.39 sec]: learning rate : 0.000013 loss : 0.701059 +[17:36:52.227] iteration 72800 [1719.35 sec]: learning rate : 0.000013 loss : 0.506549 +[17:38:19.089] iteration 72900 [1806.20 sec]: learning rate : 0.000013 loss : 0.492838 +[17:38:23.403] Epoch 34 Evaluation: +[17:39:13.717] average MSE: 0.054529786109924316 average PSNR: 28.26267570696865 average SSIM: 0.6032703595003429 +[17:40:36.577] iteration 73000 [82.80 sec]: learning rate : 0.000013 loss : 0.683237 +[17:42:03.421] iteration 73100 [169.64 sec]: learning rate : 0.000013 loss : 0.548331 +[17:43:30.334] iteration 73200 [256.55 sec]: learning rate : 0.000013 loss : 0.546970 +[17:44:57.296] iteration 73300 [343.52 sec]: learning rate : 0.000013 loss : 0.449644 +[17:46:24.175] iteration 73400 [430.39 sec]: learning rate : 0.000013 loss : 0.497707 +[17:47:51.123] iteration 73500 [517.35 sec]: learning rate : 0.000013 loss : 0.951446 +[17:49:18.060] iteration 73600 [604.28 sec]: learning rate : 0.000013 loss : 0.564296 +[17:50:44.935] iteration 73700 [691.16 sec]: learning rate : 0.000013 loss : 0.646891 +[17:52:11.877] iteration 73800 [778.10 sec]: learning rate : 0.000013 loss : 0.459015 +[17:53:38.762] iteration 73900 [864.98 sec]: learning rate : 0.000013 loss : 0.795680 +[17:55:05.701] iteration 74000 [951.92 sec]: learning rate : 0.000013 loss : 0.580883 +[17:56:32.648] iteration 74100 [1038.87 sec]: learning rate : 0.000013 loss : 0.864461 +[17:57:59.541] iteration 74200 [1125.76 sec]: learning rate : 0.000013 loss : 0.607161 +[17:59:26.493] iteration 74300 [1212.71 sec]: learning rate : 0.000013 loss : 0.624041 +[18:00:53.471] iteration 74400 [1299.69 sec]: learning rate : 0.000013 loss : 0.454513 +[18:02:20.343] iteration 74500 [1386.56 sec]: learning rate : 0.000013 loss : 0.558715 +[18:03:47.301] iteration 74600 [1473.52 sec]: learning rate : 0.000013 loss : 0.633678 +[18:05:14.269] iteration 74700 [1560.49 sec]: learning rate : 0.000013 loss : 0.699301 +[18:06:41.166] iteration 74800 [1647.39 sec]: learning rate : 0.000013 loss : 0.443320 +[18:08:08.126] iteration 74900 [1734.34 sec]: learning rate : 0.000013 loss : 0.521007 +[18:09:24.645] Epoch 35 Evaluation: +[18:10:17.471] average MSE: 
0.05466415360569954 average PSNR: 28.248631562888903 average SSIM: 0.6024604948448192 +[18:10:28.152] iteration 75000 [10.62 sec]: learning rate : 0.000013 loss : 0.437117 +[18:11:55.023] iteration 75100 [97.49 sec]: learning rate : 0.000013 loss : 0.594819 +[18:13:21.991] iteration 75200 [184.46 sec]: learning rate : 0.000013 loss : 0.575674 +[18:14:48.884] iteration 75300 [271.35 sec]: learning rate : 0.000013 loss : 0.757830 +[18:16:15.838] iteration 75400 [358.30 sec]: learning rate : 0.000013 loss : 0.626880 +[18:17:42.785] iteration 75500 [445.25 sec]: learning rate : 0.000013 loss : 0.478698 +[18:19:09.712] iteration 75600 [532.18 sec]: learning rate : 0.000013 loss : 0.782661 +[18:20:36.724] iteration 75700 [619.19 sec]: learning rate : 0.000013 loss : 0.448108 +[18:22:03.657] iteration 75800 [706.12 sec]: learning rate : 0.000013 loss : 0.826254 +[18:23:30.618] iteration 75900 [793.08 sec]: learning rate : 0.000013 loss : 0.672512 +[18:24:57.608] iteration 76000 [880.07 sec]: learning rate : 0.000013 loss : 0.495225 +[18:26:24.530] iteration 76100 [967.00 sec]: learning rate : 0.000013 loss : 0.407617 +[18:27:51.492] iteration 76200 [1053.96 sec]: learning rate : 0.000013 loss : 0.625638 +[18:29:18.451] iteration 76300 [1140.92 sec]: learning rate : 0.000013 loss : 0.527758 +[18:30:45.441] iteration 76400 [1227.91 sec]: learning rate : 0.000013 loss : 0.675954 +[18:32:12.400] iteration 76500 [1314.87 sec]: learning rate : 0.000013 loss : 0.673356 +[18:33:39.357] iteration 76600 [1401.82 sec]: learning rate : 0.000013 loss : 0.542466 +[18:35:06.376] iteration 76700 [1488.84 sec]: learning rate : 0.000013 loss : 0.754163 +[18:36:33.387] iteration 76800 [1575.85 sec]: learning rate : 0.000013 loss : 0.690745 +[18:38:00.310] iteration 76900 [1662.78 sec]: learning rate : 0.000013 loss : 0.634045 +[18:39:27.311] iteration 77000 [1749.78 sec]: learning rate : 0.000013 loss : 0.555658 +[18:40:29.012] Epoch 36 Evaluation: +[18:41:21.468] average MSE: 0.054672494530677795 average PSNR: 28.250387374128916 average SSIM: 0.6032694922952092 +[18:41:46.926] iteration 77100 [25.40 sec]: learning rate : 0.000013 loss : 0.530015 +[18:43:13.971] iteration 77200 [112.44 sec]: learning rate : 0.000013 loss : 0.618906 +[18:44:40.957] iteration 77300 [199.43 sec]: learning rate : 0.000013 loss : 0.459458 +[18:46:07.886] iteration 77400 [286.35 sec]: learning rate : 0.000013 loss : 0.824724 +[18:47:34.858] iteration 77500 [373.33 sec]: learning rate : 0.000013 loss : 0.660993 +[18:49:01.789] iteration 77600 [460.26 sec]: learning rate : 0.000013 loss : 0.528871 +[18:50:28.785] iteration 77700 [547.25 sec]: learning rate : 0.000013 loss : 0.788358 +[18:51:55.797] iteration 77800 [634.27 sec]: learning rate : 0.000013 loss : 0.639068 +[18:53:22.732] iteration 77900 [721.20 sec]: learning rate : 0.000013 loss : 0.870011 +[18:54:49.710] iteration 78000 [808.18 sec]: learning rate : 0.000013 loss : 0.486791 +[18:56:16.684] iteration 78100 [895.15 sec]: learning rate : 0.000013 loss : 0.616822 +[18:57:43.592] iteration 78200 [982.06 sec]: learning rate : 0.000013 loss : 0.613874 +[18:59:10.531] iteration 78300 [1069.00 sec]: learning rate : 0.000013 loss : 0.511961 +[19:00:37.459] iteration 78400 [1155.93 sec]: learning rate : 0.000013 loss : 0.517099 +[19:02:04.450] iteration 78500 [1242.92 sec]: learning rate : 0.000013 loss : 0.827525 +[19:03:31.442] iteration 78600 [1329.91 sec]: learning rate : 0.000013 loss : 0.604990 +[19:04:58.381] iteration 78700 [1416.85 sec]: learning rate : 0.000013 loss : 
0.556006 +[19:06:25.371] iteration 78800 [1503.84 sec]: learning rate : 0.000013 loss : 0.764828 +[19:07:52.341] iteration 78900 [1590.81 sec]: learning rate : 0.000013 loss : 0.618354 +[19:09:19.236] iteration 79000 [1677.71 sec]: learning rate : 0.000013 loss : 0.590379 +[19:10:46.241] iteration 79100 [1764.71 sec]: learning rate : 0.000013 loss : 0.586044 +[19:11:33.133] Epoch 37 Evaluation: +[19:12:25.187] average MSE: 0.05456169322133064 average PSNR: 28.264027595703343 average SSIM: 0.6026332300819792 +[19:13:05.520] iteration 79200 [40.27 sec]: learning rate : 0.000013 loss : 0.511329 +[19:14:32.402] iteration 79300 [127.15 sec]: learning rate : 0.000013 loss : 0.790596 +[19:15:59.348] iteration 79400 [214.10 sec]: learning rate : 0.000013 loss : 0.604448 +[19:17:26.302] iteration 79500 [301.05 sec]: learning rate : 0.000013 loss : 0.528919 +[19:18:53.205] iteration 79600 [387.96 sec]: learning rate : 0.000013 loss : 0.596911 +[19:20:20.165] iteration 79700 [474.91 sec]: learning rate : 0.000013 loss : 0.782642 +[19:21:47.134] iteration 79800 [561.88 sec]: learning rate : 0.000013 loss : 0.742556 +[19:23:14.041] iteration 79900 [648.79 sec]: learning rate : 0.000013 loss : 0.608868 +[19:24:41.018] iteration 80000 [735.77 sec]: learning rate : 0.000003 loss : 0.865048 +[19:24:41.177] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_80000.pth +[19:26:08.098] iteration 80100 [822.85 sec]: learning rate : 0.000006 loss : 0.562691 +[19:27:34.984] iteration 80200 [909.73 sec]: learning rate : 0.000006 loss : 0.664300 +[19:29:01.974] iteration 80300 [996.72 sec]: learning rate : 0.000006 loss : 0.521210 +[19:30:28.865] iteration 80400 [1083.61 sec]: learning rate : 0.000006 loss : 0.702085 +[19:31:55.811] iteration 80500 [1170.56 sec]: learning rate : 0.000006 loss : 0.811921 +[19:33:22.790] iteration 80600 [1257.54 sec]: learning rate : 0.000006 loss : 0.700604 +[19:34:49.703] iteration 80700 [1344.45 sec]: learning rate : 0.000006 loss : 0.641042 +[19:36:16.667] iteration 80800 [1431.42 sec]: learning rate : 0.000006 loss : 0.670426 +[19:37:43.559] iteration 80900 [1518.33 sec]: learning rate : 0.000006 loss : 0.609019 +[19:39:10.557] iteration 81000 [1605.31 sec]: learning rate : 0.000006 loss : 0.732523 +[19:40:37.499] iteration 81100 [1692.25 sec]: learning rate : 0.000006 loss : 0.848621 +[19:42:04.373] iteration 81200 [1779.12 sec]: learning rate : 0.000006 loss : 0.767695 +[19:42:36.534] Epoch 38 Evaluation: +[19:43:26.931] average MSE: 0.0544360876083374 average PSNR: 28.27312191082898 average SSIM: 0.6032143461058076 +[19:44:21.915] iteration 81300 [54.92 sec]: learning rate : 0.000006 loss : 0.613368 +[19:45:48.926] iteration 81400 [141.93 sec]: learning rate : 0.000006 loss : 0.679548 +[19:47:15.819] iteration 81500 [228.83 sec]: learning rate : 0.000006 loss : 0.453332 +[19:48:42.764] iteration 81600 [315.77 sec]: learning rate : 0.000006 loss : 0.595009 +[19:50:09.664] iteration 81700 [402.67 sec]: learning rate : 0.000006 loss : 0.717682 +[19:51:36.652] iteration 81800 [489.66 sec]: learning rate : 0.000006 loss : 0.350242 +[19:53:03.653] iteration 81900 [576.66 sec]: learning rate : 0.000006 loss : 0.884562 +[19:54:30.561] iteration 82000 [663.57 sec]: learning rate : 0.000006 loss : 0.406053 +[19:55:57.553] iteration 82100 [750.56 sec]: learning rate : 0.000006 loss : 0.550120 +[19:57:24.456] iteration 82200 [837.46 sec]: learning rate : 0.000006 loss : 0.766344 +[19:58:51.433] iteration 82300 [924.44 sec]: learning rate : 0.000006 loss : 0.441294 +[20:00:18.400] 
iteration 82400 [1011.41 sec]: learning rate : 0.000006 loss : 0.603833 +[20:01:45.304] iteration 82500 [1098.31 sec]: learning rate : 0.000006 loss : 0.572126 +[20:03:12.303] iteration 82600 [1185.31 sec]: learning rate : 0.000006 loss : 0.790072 +[20:04:39.298] iteration 82700 [1272.30 sec]: learning rate : 0.000006 loss : 0.546519 +[20:06:06.230] iteration 82800 [1359.24 sec]: learning rate : 0.000006 loss : 0.515622 +[20:07:33.176] iteration 82900 [1446.18 sec]: learning rate : 0.000006 loss : 0.895252 +[20:09:00.097] iteration 83000 [1533.10 sec]: learning rate : 0.000006 loss : 0.646662 +[20:10:27.154] iteration 83100 [1620.16 sec]: learning rate : 0.000006 loss : 0.420440 +[20:11:54.096] iteration 83200 [1707.10 sec]: learning rate : 0.000006 loss : 0.242081 +[20:13:21.014] iteration 83300 [1794.02 sec]: learning rate : 0.000006 loss : 0.722668 +[20:13:38.377] Epoch 39 Evaluation: +[20:14:28.513] average MSE: 0.05449827387928963 average PSNR: 28.269709999935742 average SSIM: 0.6035322639830764 +[20:15:38.414] iteration 83400 [69.84 sec]: learning rate : 0.000006 loss : 0.667198 +[20:17:05.394] iteration 83500 [156.82 sec]: learning rate : 0.000006 loss : 0.791940 +[20:18:32.312] iteration 83600 [243.74 sec]: learning rate : 0.000006 loss : 0.506514 +[20:19:59.319] iteration 83700 [330.74 sec]: learning rate : 0.000006 loss : 0.693759 +[20:21:26.251] iteration 83800 [417.68 sec]: learning rate : 0.000006 loss : 0.770283 +[20:22:53.209] iteration 83900 [504.63 sec]: learning rate : 0.000006 loss : 0.503003 +[20:24:20.217] iteration 84000 [591.64 sec]: learning rate : 0.000006 loss : 0.561854 +[20:25:47.157] iteration 84100 [678.58 sec]: learning rate : 0.000006 loss : 0.521698 +[20:27:14.141] iteration 84200 [765.56 sec]: learning rate : 0.000006 loss : 0.459057 +[20:28:41.062] iteration 84300 [852.49 sec]: learning rate : 0.000006 loss : 0.547077 +[20:30:08.082] iteration 84400 [939.51 sec]: learning rate : 0.000006 loss : 0.498483 +[20:31:35.066] iteration 84500 [1026.49 sec]: learning rate : 0.000006 loss : 0.496040 +[20:33:02.009] iteration 84600 [1113.43 sec]: learning rate : 0.000006 loss : 0.566885 +[20:34:29.016] iteration 84700 [1200.44 sec]: learning rate : 0.000006 loss : 0.468404 +[20:35:55.957] iteration 84800 [1287.38 sec]: learning rate : 0.000006 loss : 0.746009 +[20:37:22.929] iteration 84900 [1374.35 sec]: learning rate : 0.000006 loss : 0.763690 +[20:38:49.933] iteration 85000 [1461.36 sec]: learning rate : 0.000006 loss : 0.601398 +[20:40:16.884] iteration 85100 [1548.31 sec]: learning rate : 0.000006 loss : 0.401801 +[20:41:43.837] iteration 85200 [1635.26 sec]: learning rate : 0.000006 loss : 0.750923 +[20:43:10.800] iteration 85300 [1722.29 sec]: learning rate : 0.000006 loss : 0.412610 +[20:44:37.706] iteration 85400 [1809.13 sec]: learning rate : 0.000006 loss : 0.700627 +[20:44:40.298] Epoch 40 Evaluation: +[20:45:31.340] average MSE: 0.05453088879585266 average PSNR: 28.27015705151995 average SSIM: 0.6039861729693019 +[20:46:55.970] iteration 85500 [84.57 sec]: learning rate : 0.000006 loss : 0.395737 +[20:48:22.850] iteration 85600 [171.45 sec]: learning rate : 0.000006 loss : 0.673076 +[20:49:49.808] iteration 85700 [258.40 sec]: learning rate : 0.000006 loss : 0.703743 +[20:51:16.811] iteration 85800 [345.41 sec]: learning rate : 0.000006 loss : 0.568968 +[20:52:43.719] iteration 85900 [432.32 sec]: learning rate : 0.000006 loss : 0.525585 +[20:54:10.703] iteration 86000 [519.30 sec]: learning rate : 0.000006 loss : 0.810662 +[20:55:37.686] iteration 
86100 [606.28 sec]: learning rate : 0.000006 loss : 0.686098 +[20:57:04.615] iteration 86200 [693.21 sec]: learning rate : 0.000006 loss : 0.550659 +[20:58:31.576] iteration 86300 [780.17 sec]: learning rate : 0.000006 loss : 0.536103 +[20:59:58.539] iteration 86400 [867.14 sec]: learning rate : 0.000006 loss : 0.784032 +[21:01:25.471] iteration 86500 [954.07 sec]: learning rate : 0.000006 loss : 0.482069 +[21:02:52.435] iteration 86600 [1041.03 sec]: learning rate : 0.000006 loss : 0.470734 +[21:04:19.385] iteration 86700 [1127.98 sec]: learning rate : 0.000006 loss : 0.591539 +[21:05:46.299] iteration 86800 [1214.90 sec]: learning rate : 0.000006 loss : 0.903680 +[21:07:13.253] iteration 86900 [1301.85 sec]: learning rate : 0.000006 loss : 0.593872 +[21:08:40.224] iteration 87000 [1388.82 sec]: learning rate : 0.000006 loss : 0.542740 +[21:10:07.173] iteration 87100 [1475.77 sec]: learning rate : 0.000006 loss : 0.418116 +[21:11:34.134] iteration 87200 [1562.73 sec]: learning rate : 0.000006 loss : 0.491435 +[21:13:01.097] iteration 87300 [1649.69 sec]: learning rate : 0.000006 loss : 0.543590 +[21:14:27.989] iteration 87400 [1736.59 sec]: learning rate : 0.000006 loss : 0.379154 +[21:15:42.772] Epoch 41 Evaluation: +[21:16:34.825] average MSE: 0.0545741431415081 average PSNR: 28.26110482855388 average SSIM: 0.6033126399682456 +[21:16:47.265] iteration 87500 [12.38 sec]: learning rate : 0.000006 loss : 0.730692 +[21:18:14.284] iteration 87600 [99.39 sec]: learning rate : 0.000006 loss : 0.356822 +[21:19:41.179] iteration 87700 [186.29 sec]: learning rate : 0.000006 loss : 0.742943 +[21:21:08.154] iteration 87800 [273.27 sec]: learning rate : 0.000006 loss : 0.501430 +[21:22:35.146] iteration 87900 [360.26 sec]: learning rate : 0.000006 loss : 0.851021 +[21:24:02.041] iteration 88000 [447.15 sec]: learning rate : 0.000006 loss : 0.363459 +[21:25:28.991] iteration 88100 [534.10 sec]: learning rate : 0.000006 loss : 0.670330 +[21:26:56.012] iteration 88200 [621.12 sec]: learning rate : 0.000006 loss : 0.423284 +[21:28:22.948] iteration 88300 [708.06 sec]: learning rate : 0.000006 loss : 0.693172 +[21:29:49.924] iteration 88400 [795.03 sec]: learning rate : 0.000006 loss : 0.787658 +[21:31:16.891] iteration 88500 [882.00 sec]: learning rate : 0.000006 loss : 0.650114 +[21:32:43.814] iteration 88600 [968.93 sec]: learning rate : 0.000006 loss : 0.467843 +[21:34:10.822] iteration 88700 [1055.93 sec]: learning rate : 0.000006 loss : 0.713269 +[21:35:37.797] iteration 88800 [1142.91 sec]: learning rate : 0.000006 loss : 0.903523 +[21:37:04.717] iteration 88900 [1229.83 sec]: learning rate : 0.000006 loss : 0.921937 +[21:38:31.703] iteration 89000 [1316.82 sec]: learning rate : 0.000006 loss : 0.782878 +[21:39:58.697] iteration 89100 [1403.81 sec]: learning rate : 0.000006 loss : 0.939209 +[21:41:25.641] iteration 89200 [1490.75 sec]: learning rate : 0.000006 loss : 0.503281 +[21:42:52.666] iteration 89300 [1577.78 sec]: learning rate : 0.000006 loss : 0.833292 +[21:44:19.667] iteration 89400 [1664.78 sec]: learning rate : 0.000006 loss : 0.752596 +[21:45:46.594] iteration 89500 [1751.71 sec]: learning rate : 0.000006 loss : 0.670283 +[21:46:46.567] Epoch 42 Evaluation: +[21:47:36.965] average MSE: 0.05452476069331169 average PSNR: 28.266436019965894 average SSIM: 0.6038203206664287 +[21:48:04.143] iteration 89600 [27.11 sec]: learning rate : 0.000006 loss : 0.530001 +[21:49:31.105] iteration 89700 [114.08 sec]: learning rate : 0.000006 loss : 1.018061 +[21:50:58.004] iteration 89800 [200.97 
sec]: learning rate : 0.000006 loss : 0.514207 +[21:52:25.006] iteration 89900 [287.98 sec]: learning rate : 0.000006 loss : 1.299710 +[21:53:51.992] iteration 90000 [374.96 sec]: learning rate : 0.000006 loss : 0.779428 +[21:55:18.948] iteration 90100 [461.92 sec]: learning rate : 0.000006 loss : 0.495166 +[21:56:45.993] iteration 90200 [548.97 sec]: learning rate : 0.000006 loss : 0.555407 +[21:58:12.971] iteration 90300 [635.94 sec]: learning rate : 0.000006 loss : 0.400898 +[21:59:39.931] iteration 90400 [722.90 sec]: learning rate : 0.000006 loss : 1.047717 +[22:01:06.923] iteration 90500 [809.89 sec]: learning rate : 0.000006 loss : 0.600494 +[22:02:33.858] iteration 90600 [896.83 sec]: learning rate : 0.000006 loss : 0.544550 +[22:04:00.839] iteration 90700 [983.81 sec]: learning rate : 0.000006 loss : 0.645106 +[22:05:27.842] iteration 90800 [1070.81 sec]: learning rate : 0.000006 loss : 0.914053 +[22:06:54.783] iteration 90900 [1157.75 sec]: learning rate : 0.000006 loss : 0.383747 +[22:08:21.809] iteration 91000 [1244.78 sec]: learning rate : 0.000006 loss : 0.649829 +[22:09:48.771] iteration 91100 [1331.74 sec]: learning rate : 0.000006 loss : 0.524301 +[22:11:15.765] iteration 91200 [1418.74 sec]: learning rate : 0.000006 loss : 0.355092 +[22:12:42.805] iteration 91300 [1505.78 sec]: learning rate : 0.000006 loss : 0.687346 +[22:14:09.750] iteration 91400 [1592.72 sec]: learning rate : 0.000006 loss : 0.734529 +[22:15:36.770] iteration 91500 [1679.74 sec]: learning rate : 0.000006 loss : 0.586521 +[22:17:03.770] iteration 91600 [1766.74 sec]: learning rate : 0.000006 loss : 0.587880 +[22:17:48.943] Epoch 43 Evaluation: +[22:18:39.390] average MSE: 0.054386790841817856 average PSNR: 28.283645719174142 average SSIM: 0.6037610207401088 +[22:19:21.332] iteration 91700 [41.88 sec]: learning rate : 0.000006 loss : 0.418780 +[22:20:48.356] iteration 91800 [128.90 sec]: learning rate : 0.000006 loss : 0.749003 +[22:22:15.275] iteration 91900 [215.82 sec]: learning rate : 0.000006 loss : 1.016500 +[22:23:42.270] iteration 92000 [302.82 sec]: learning rate : 0.000006 loss : 0.548053 +[22:25:09.249] iteration 92100 [389.80 sec]: learning rate : 0.000006 loss : 0.512244 +[22:26:36.177] iteration 92200 [476.72 sec]: learning rate : 0.000006 loss : 0.634836 +[22:28:03.192] iteration 92300 [563.74 sec]: learning rate : 0.000006 loss : 0.803690 +[22:29:30.211] iteration 92400 [650.76 sec]: learning rate : 0.000006 loss : 0.810050 +[22:30:57.155] iteration 92500 [737.70 sec]: learning rate : 0.000006 loss : 0.517123 +[22:32:24.137] iteration 92600 [824.68 sec]: learning rate : 0.000006 loss : 0.578440 +[22:33:51.112] iteration 92700 [911.66 sec]: learning rate : 0.000006 loss : 0.741489 +[22:35:18.029] iteration 92800 [998.58 sec]: learning rate : 0.000006 loss : 0.583890 +[22:36:45.001] iteration 92900 [1085.55 sec]: learning rate : 0.000006 loss : 0.415389 +[22:38:11.981] iteration 93000 [1172.53 sec]: learning rate : 0.000006 loss : 0.468858 +[22:39:38.947] iteration 93100 [1259.49 sec]: learning rate : 0.000006 loss : 0.793739 +[22:41:05.966] iteration 93200 [1346.51 sec]: learning rate : 0.000006 loss : 0.730892 +[22:42:32.881] iteration 93300 [1433.43 sec]: learning rate : 0.000006 loss : 0.419274 +[22:43:59.879] iteration 93400 [1520.42 sec]: learning rate : 0.000006 loss : 0.635874 +[22:45:26.833] iteration 93500 [1607.38 sec]: learning rate : 0.000006 loss : 0.642532 +[22:46:53.734] iteration 93600 [1694.28 sec]: learning rate : 0.000006 loss : 0.608941 +[22:48:20.740] iteration 93700 
[1781.29 sec]: learning rate : 0.000006 loss : 0.620520 +[22:48:51.122] Epoch 44 Evaluation: +[22:49:41.388] average MSE: 0.054446835070848465 average PSNR: 28.27616756299156 average SSIM: 0.6040907261460614 +[22:50:38.217] iteration 93800 [56.77 sec]: learning rate : 0.000006 loss : 0.480586 +[22:52:05.068] iteration 93900 [143.62 sec]: learning rate : 0.000006 loss : 0.541282 +[22:53:31.991] iteration 94000 [230.54 sec]: learning rate : 0.000006 loss : 0.771015 +[22:54:58.912] iteration 94100 [317.46 sec]: learning rate : 0.000006 loss : 0.637244 +[22:56:25.802] iteration 94200 [404.35 sec]: learning rate : 0.000006 loss : 0.342625 +[22:57:52.749] iteration 94300 [491.30 sec]: learning rate : 0.000006 loss : 0.870344 +[22:59:19.675] iteration 94400 [578.22 sec]: learning rate : 0.000006 loss : 0.795202 +[23:00:46.676] iteration 94500 [665.23 sec]: learning rate : 0.000006 loss : 0.742162 +[23:02:13.681] iteration 94600 [752.23 sec]: learning rate : 0.000006 loss : 0.430493 +[23:03:40.638] iteration 94700 [839.19 sec]: learning rate : 0.000006 loss : 0.438087 +[23:05:07.604] iteration 94800 [926.15 sec]: learning rate : 0.000006 loss : 0.676210 +[23:06:34.606] iteration 94900 [1013.15 sec]: learning rate : 0.000006 loss : 0.552846 +[23:08:01.520] iteration 95000 [1100.07 sec]: learning rate : 0.000006 loss : 0.658782 +[23:09:28.522] iteration 95100 [1187.07 sec]: learning rate : 0.000006 loss : 0.647538 +[23:10:55.464] iteration 95200 [1274.01 sec]: learning rate : 0.000006 loss : 0.624249 +[23:12:22.401] iteration 95300 [1360.97 sec]: learning rate : 0.000006 loss : 0.893483 +[23:13:49.415] iteration 95400 [1447.96 sec]: learning rate : 0.000006 loss : 0.533846 +[23:15:16.362] iteration 95500 [1534.91 sec]: learning rate : 0.000006 loss : 0.453086 +[23:16:43.365] iteration 95600 [1621.91 sec]: learning rate : 0.000006 loss : 0.278753 +[23:18:10.303] iteration 95700 [1708.85 sec]: learning rate : 0.000006 loss : 0.574590 +[23:19:37.233] iteration 95800 [1795.78 sec]: learning rate : 0.000006 loss : 0.581671 +[23:19:52.846] Epoch 45 Evaluation: +[23:20:44.235] average MSE: 0.054444003850221634 average PSNR: 28.273443074397207 average SSIM: 0.6036210232377142 +[23:21:55.883] iteration 95900 [71.59 sec]: learning rate : 0.000006 loss : 0.650880 +[23:23:22.846] iteration 96000 [158.55 sec]: learning rate : 0.000006 loss : 0.441995 +[23:24:49.763] iteration 96100 [245.47 sec]: learning rate : 0.000006 loss : 0.368154 +[23:26:16.736] iteration 96200 [332.44 sec]: learning rate : 0.000006 loss : 0.585796 +[23:27:43.704] iteration 96300 [419.41 sec]: learning rate : 0.000006 loss : 0.582124 +[23:29:10.650] iteration 96400 [506.35 sec]: learning rate : 0.000006 loss : 0.591770 +[23:30:37.650] iteration 96500 [593.35 sec]: learning rate : 0.000006 loss : 0.599764 +[23:32:04.603] iteration 96600 [680.30 sec]: learning rate : 0.000006 loss : 0.515682 +[23:33:31.537] iteration 96700 [767.24 sec]: learning rate : 0.000006 loss : 1.032427 +[23:34:58.513] iteration 96800 [854.21 sec]: learning rate : 0.000006 loss : 0.394783 +[23:36:25.464] iteration 96900 [941.17 sec]: learning rate : 0.000006 loss : 0.721586 +[23:37:52.422] iteration 97000 [1028.12 sec]: learning rate : 0.000006 loss : 0.342356 +[23:39:19.422] iteration 97100 [1115.12 sec]: learning rate : 0.000006 loss : 0.429395 +[23:40:46.368] iteration 97200 [1202.09 sec]: learning rate : 0.000006 loss : 0.696658 +[23:42:13.412] iteration 97300 [1289.11 sec]: learning rate : 0.000006 loss : 0.696513 +[23:43:40.344] iteration 97400 [1376.05 sec]: 
learning rate : 0.000006 loss : 0.661987 +[23:45:07.351] iteration 97500 [1463.05 sec]: learning rate : 0.000006 loss : 0.707027 +[23:46:34.355] iteration 97600 [1550.06 sec]: learning rate : 0.000006 loss : 0.571873 +[23:48:01.268] iteration 97700 [1636.97 sec]: learning rate : 0.000006 loss : 0.486266 +[23:49:28.237] iteration 97800 [1723.94 sec]: learning rate : 0.000006 loss : 0.664544 +[23:50:55.174] iteration 97900 [1810.88 sec]: learning rate : 0.000006 loss : 0.519288 +[23:50:56.022] Epoch 46 Evaluation: +[23:51:45.944] average MSE: 0.05439363047480583 average PSNR: 28.28102222267758 average SSIM: 0.6037284678448768 +[23:53:12.192] iteration 98000 [86.19 sec]: learning rate : 0.000006 loss : 0.990284 +[23:54:39.194] iteration 98100 [173.19 sec]: learning rate : 0.000006 loss : 0.523376 +[23:56:06.149] iteration 98200 [260.14 sec]: learning rate : 0.000006 loss : 0.619781 +[23:57:33.056] iteration 98300 [347.05 sec]: learning rate : 0.000006 loss : 0.537919 +[23:58:59.991] iteration 98400 [433.98 sec]: learning rate : 0.000006 loss : 0.633666 +[00:00:26.988] iteration 98500 [520.98 sec]: learning rate : 0.000006 loss : 0.754822 +[00:01:53.918] iteration 98600 [607.91 sec]: learning rate : 0.000006 loss : 0.699686 +[00:03:20.892] iteration 98700 [694.89 sec]: learning rate : 0.000006 loss : 0.468678 +[00:04:47.803] iteration 98800 [781.80 sec]: learning rate : 0.000006 loss : 0.574144 +[00:06:14.796] iteration 98900 [868.79 sec]: learning rate : 0.000006 loss : 0.626043 +[00:07:41.754] iteration 99000 [955.75 sec]: learning rate : 0.000006 loss : 0.443504 +[00:09:08.698] iteration 99100 [1042.69 sec]: learning rate : 0.000006 loss : 0.659962 +[00:10:35.670] iteration 99200 [1129.66 sec]: learning rate : 0.000006 loss : 0.642841 +[00:12:02.641] iteration 99300 [1216.63 sec]: learning rate : 0.000006 loss : 0.664181 +[00:13:29.588] iteration 99400 [1303.58 sec]: learning rate : 0.000006 loss : 0.540989 +[00:14:56.570] iteration 99500 [1390.56 sec]: learning rate : 0.000006 loss : 0.773988 +[00:16:23.524] iteration 99600 [1477.52 sec]: learning rate : 0.000006 loss : 0.524634 +[00:17:50.552] iteration 99700 [1564.55 sec]: learning rate : 0.000006 loss : 0.449999 +[00:19:17.526] iteration 99800 [1651.52 sec]: learning rate : 0.000006 loss : 0.597606 +[00:20:44.413] iteration 99900 [1738.41 sec]: learning rate : 0.000006 loss : 0.382165 +[00:21:57.384] Epoch 47 Evaluation: +[00:22:49.834] average MSE: 0.054431408643722534 average PSNR: 28.27956611829288 average SSIM: 0.6037958477551008 +[00:23:03.978] iteration 100000 [14.08 sec]: learning rate : 0.000002 loss : 0.427649 +[00:23:04.147] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_100000.pth +[00:23:04.990] Epoch 48 Evaluation: +[00:23:56.734] average MSE: 0.054349660873413086 average PSNR: 28.287743118203952 average SSIM: 0.6041746145529535 +[00:23:57.005] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_100000.pth diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log/events.out.tfevents.1752647703.GCRSANDBOX133 b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log/events.out.tfevents.1752647703.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..b3349155dd109739d548d77af3f55042adfae243 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log/events.out.tfevents.1752647703.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:9eabe50a57241eebbd815d58f41629524a88fab5738d11f7f453ebf5e2e87bbc +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..a9d1870790c8e96dbe06e39baa9843a47307c1be --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b5cd71ed2a077c3ff8ceec27f35125a8591422cb06698fa3ac5d9fcba721fdf +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..79ea7f56ebbce2860ef184838f5551bcffd37a17 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_/log.txt @@ -0,0 +1,1119 @@ +[05:58:34.499] Namespace(root_path='/home/v-qichen3/MRI_recon/data/fastmri', MRIDOWN='8X', low_field_SNR=0, phase='train', gpu='1', exp='FSMNet_fastmri_8x', max_iterations=100000, batch_size=4, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='equispaced', CENTER_FRACTIONS=[0.04], ACCELERATIONS=[8], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[06:00:04.038] iteration 100 [86.90 sec]: learning rate : 0.000100 loss : 0.713801 +[06:01:30.073] iteration 200 [172.94 sec]: learning rate : 0.000100 loss : 0.539555 +[06:02:56.097] iteration 300 [258.96 sec]: learning rate : 0.000100 loss : 0.687900 +[06:04:22.197] iteration 400 [345.06 sec]: learning rate : 0.000100 loss : 0.744821 +[06:05:48.330] iteration 500 [431.19 sec]: learning rate : 0.000100 loss : 0.674234 +[06:07:14.422] iteration 600 [517.28 sec]: learning rate : 0.000100 loss : 1.271755 +[06:08:40.569] iteration 700 [603.43 sec]: learning rate : 0.000100 loss : 0.667026 +[06:10:06.778] iteration 800 [689.64 sec]: learning rate : 0.000100 loss : 0.671631 +[06:11:32.910] iteration 900 [775.77 sec]: learning rate : 0.000100 loss : 0.912368 +[06:12:59.102] iteration 1000 [861.97 sec]: learning rate : 0.000100 loss : 0.627673 +[06:14:25.316] iteration 1100 [948.18 sec]: learning rate : 0.000100 loss : 0.771434 +[06:15:51.479] iteration 1200 [1034.34 sec]: learning rate : 0.000100 loss : 0.711366 +[06:17:17.753] iteration 1300 [1120.62 sec]: learning rate : 0.000100 loss : 0.460900 +[06:18:44.025] iteration 1400 [1206.89 sec]: learning rate : 0.000100 loss : 1.201037 +[06:20:10.255] iteration 1500 [1293.12 sec]: learning rate : 0.000100 loss : 0.746585 +[06:21:36.545] iteration 1600 [1379.41 sec]: learning rate : 0.000100 loss : 1.042526 +[06:23:02.825] iteration 1700 [1465.69 sec]: learning rate : 0.000100 loss : 0.888596 +[06:24:29.044] iteration 1800 [1551.91 sec]: learning rate : 0.000100 loss : 1.232691 +[06:25:55.318] iteration 1900 [1638.18 sec]: learning rate : 0.000100 loss : 
0.642048 +[06:27:21.598] iteration 2000 [1724.46 sec]: learning rate : 0.000100 loss : 0.682200 +[06:28:33.121] Epoch 0 Evaluation: +[06:29:22.461] average MSE: 0.06779099255800247 average PSNR: 27.02808176500622 average SSIM: 0.5512911658675675 +[06:29:37.383] iteration 2100 [14.86 sec]: learning rate : 0.000100 loss : 0.750754 +[06:31:03.726] iteration 2200 [101.20 sec]: learning rate : 0.000100 loss : 0.659021 +[06:32:29.994] iteration 2300 [187.47 sec]: learning rate : 0.000100 loss : 0.830933 +[06:33:56.211] iteration 2400 [273.69 sec]: learning rate : 0.000100 loss : 0.966578 +[06:35:22.503] iteration 2500 [359.98 sec]: learning rate : 0.000100 loss : 0.899703 +[06:36:48.760] iteration 2600 [446.24 sec]: learning rate : 0.000100 loss : 0.617223 +[06:38:15.061] iteration 2700 [532.54 sec]: learning rate : 0.000100 loss : 1.100699 +[06:39:41.382] iteration 2800 [618.86 sec]: learning rate : 0.000100 loss : 0.952201 +[06:41:07.611] iteration 2900 [705.09 sec]: learning rate : 0.000100 loss : 1.069155 +[06:42:33.903] iteration 3000 [791.38 sec]: learning rate : 0.000100 loss : 0.857852 +[06:44:00.154] iteration 3100 [877.63 sec]: learning rate : 0.000100 loss : 0.955480 +[06:45:26.404] iteration 3200 [963.88 sec]: learning rate : 0.000100 loss : 0.445331 +[06:46:52.709] iteration 3300 [1050.19 sec]: learning rate : 0.000100 loss : 0.880401 +[06:48:19.010] iteration 3400 [1136.49 sec]: learning rate : 0.000100 loss : 0.733222 +[06:49:45.257] iteration 3500 [1222.73 sec]: learning rate : 0.000100 loss : 1.177640 +[06:51:11.556] iteration 3600 [1309.03 sec]: learning rate : 0.000100 loss : 0.608205 +[06:52:37.851] iteration 3700 [1395.33 sec]: learning rate : 0.000100 loss : 0.570668 +[06:54:04.094] iteration 3800 [1481.57 sec]: learning rate : 0.000100 loss : 0.443695 +[06:55:30.389] iteration 3900 [1567.87 sec]: learning rate : 0.000100 loss : 0.599232 +[06:56:56.701] iteration 4000 [1654.18 sec]: learning rate : 0.000100 loss : 0.642613 +[06:58:22.945] iteration 4100 [1740.42 sec]: learning rate : 0.000100 loss : 0.769693 +[06:59:19.897] Epoch 1 Evaluation: +[07:00:11.385] average MSE: 0.06673482060432434 average PSNR: 27.102565131712982 average SSIM: 0.5617522393473829 +[07:00:40.926] iteration 4200 [29.48 sec]: learning rate : 0.000100 loss : 1.395733 +[07:02:07.140] iteration 4300 [115.69 sec]: learning rate : 0.000100 loss : 0.676689 +[07:03:33.465] iteration 4400 [202.02 sec]: learning rate : 0.000100 loss : 0.518412 +[07:04:59.747] iteration 4500 [288.30 sec]: learning rate : 0.000100 loss : 0.551814 +[07:06:25.994] iteration 4600 [374.55 sec]: learning rate : 0.000100 loss : 0.920456 +[07:07:52.303] iteration 4700 [460.86 sec]: learning rate : 0.000100 loss : 0.368831 +[07:09:18.609] iteration 4800 [547.16 sec]: learning rate : 0.000100 loss : 0.717815 +[07:10:44.856] iteration 4900 [633.41 sec]: learning rate : 0.000100 loss : 0.487726 +[07:12:11.162] iteration 5000 [719.71 sec]: learning rate : 0.000100 loss : 0.556725 +[07:13:37.436] iteration 5100 [805.99 sec]: learning rate : 0.000100 loss : 0.527806 +[07:15:03.687] iteration 5200 [892.24 sec]: learning rate : 0.000100 loss : 0.599690 +[07:16:29.963] iteration 5300 [978.52 sec]: learning rate : 0.000100 loss : 0.708511 +[07:17:56.220] iteration 5400 [1064.77 sec]: learning rate : 0.000100 loss : 0.637890 +[07:19:22.442] iteration 5500 [1150.99 sec]: learning rate : 0.000100 loss : 1.257058 +[07:20:48.775] iteration 5600 [1237.33 sec]: learning rate : 0.000100 loss : 0.892468 +[07:22:15.036] iteration 5700 [1323.59 sec]: 
learning rate : 0.000100 loss : 0.797089 +[07:23:41.339] iteration 5800 [1409.89 sec]: learning rate : 0.000100 loss : 0.804234 +[07:25:07.659] iteration 5900 [1496.21 sec]: learning rate : 0.000100 loss : 0.556186 +[07:26:33.915] iteration 6000 [1582.47 sec]: learning rate : 0.000100 loss : 0.485234 +[07:28:00.207] iteration 6100 [1668.76 sec]: learning rate : 0.000100 loss : 0.549961 +[07:29:26.494] iteration 6200 [1755.05 sec]: learning rate : 0.000100 loss : 0.537607 +[07:30:08.734] Epoch 2 Evaluation: +[07:30:58.093] average MSE: 0.06311102956533432 average PSNR: 27.39658342039117 average SSIM: 0.5701182122047868 +[07:31:42.315] iteration 6300 [44.16 sec]: learning rate : 0.000100 loss : 0.783106 +[07:33:08.627] iteration 6400 [130.47 sec]: learning rate : 0.000100 loss : 0.744705 +[07:34:34.882] iteration 6500 [216.73 sec]: learning rate : 0.000100 loss : 0.501814 +[07:36:01.205] iteration 6600 [303.05 sec]: learning rate : 0.000100 loss : 0.660887 +[07:37:27.494] iteration 6700 [389.34 sec]: learning rate : 0.000100 loss : 0.842585 +[07:38:53.762] iteration 6800 [475.61 sec]: learning rate : 0.000100 loss : 0.953897 +[07:40:20.057] iteration 6900 [561.90 sec]: learning rate : 0.000100 loss : 0.868530 +[07:41:46.349] iteration 7000 [648.19 sec]: learning rate : 0.000100 loss : 0.640019 +[07:43:12.639] iteration 7100 [734.48 sec]: learning rate : 0.000100 loss : 0.569970 +[07:44:38.963] iteration 7200 [820.81 sec]: learning rate : 0.000100 loss : 0.444702 +[07:46:05.334] iteration 7300 [907.18 sec]: learning rate : 0.000100 loss : 0.924769 +[07:47:31.587] iteration 7400 [993.43 sec]: learning rate : 0.000100 loss : 1.152311 +[07:48:57.876] iteration 7500 [1079.72 sec]: learning rate : 0.000100 loss : 0.486305 +[07:50:24.105] iteration 7600 [1165.95 sec]: learning rate : 0.000100 loss : 1.159941 +[07:51:50.429] iteration 7700 [1252.27 sec]: learning rate : 0.000100 loss : 0.439353 +[07:53:16.708] iteration 7800 [1338.55 sec]: learning rate : 0.000100 loss : 0.823788 +[07:54:42.911] iteration 7900 [1424.76 sec]: learning rate : 0.000100 loss : 0.969220 +[07:56:09.157] iteration 8000 [1511.00 sec]: learning rate : 0.000100 loss : 0.604276 +[07:57:35.426] iteration 8100 [1597.27 sec]: learning rate : 0.000100 loss : 0.716725 +[07:59:01.626] iteration 8200 [1683.47 sec]: learning rate : 0.000100 loss : 0.430251 +[08:00:27.891] iteration 8300 [1769.74 sec]: learning rate : 0.000100 loss : 0.401854 +[08:00:55.484] Epoch 3 Evaluation: +[08:01:45.219] average MSE: 0.061659764498472214 average PSNR: 27.517912763930592 average SSIM: 0.576192170270393 +[08:02:44.240] iteration 8400 [58.96 sec]: learning rate : 0.000100 loss : 0.703451 +[08:04:10.463] iteration 8500 [145.18 sec]: learning rate : 0.000100 loss : 0.700282 +[08:05:36.734] iteration 8600 [231.45 sec]: learning rate : 0.000100 loss : 0.700056 +[08:07:02.990] iteration 8700 [317.71 sec]: learning rate : 0.000100 loss : 0.979556 +[08:08:29.273] iteration 8800 [403.99 sec]: learning rate : 0.000100 loss : 0.555990 +[08:09:55.597] iteration 8900 [490.32 sec]: learning rate : 0.000100 loss : 0.874196 +[08:11:21.891] iteration 9000 [576.61 sec]: learning rate : 0.000100 loss : 0.421876 +[08:12:48.222] iteration 9100 [662.94 sec]: learning rate : 0.000100 loss : 0.637884 +[08:14:14.576] iteration 9200 [749.30 sec]: learning rate : 0.000100 loss : 0.617532 +[08:15:40.802] iteration 9300 [835.52 sec]: learning rate : 0.000100 loss : 0.624824 +[08:17:07.122] iteration 9400 [921.84 sec]: learning rate : 0.000100 loss : 0.602544 +[08:18:33.411] 
iteration 9500 [1008.13 sec]: learning rate : 0.000100 loss : 0.365202 +[08:19:59.646] iteration 9600 [1094.37 sec]: learning rate : 0.000100 loss : 0.452786 +[08:21:25.927] iteration 9700 [1180.65 sec]: learning rate : 0.000100 loss : 0.346873 +[08:22:52.157] iteration 9800 [1266.88 sec]: learning rate : 0.000100 loss : 0.935819 +[08:24:18.398] iteration 9900 [1353.12 sec]: learning rate : 0.000100 loss : 0.435306 +[08:25:44.680] iteration 10000 [1439.40 sec]: learning rate : 0.000100 loss : 0.514612 +[08:27:10.924] iteration 10100 [1525.64 sec]: learning rate : 0.000100 loss : 0.656352 +[08:28:37.125] iteration 10200 [1611.84 sec]: learning rate : 0.000100 loss : 0.476490 +[08:30:03.412] iteration 10300 [1698.13 sec]: learning rate : 0.000100 loss : 0.668886 +[08:31:29.632] iteration 10400 [1784.35 sec]: learning rate : 0.000100 loss : 0.523053 +[08:31:42.532] Epoch 4 Evaluation: +[08:32:31.906] average MSE: 0.06149167940020561 average PSNR: 27.529866096730274 average SSIM: 0.5783549927628137 +[08:33:45.419] iteration 10500 [73.45 sec]: learning rate : 0.000100 loss : 0.641283 +[08:35:11.771] iteration 10600 [159.80 sec]: learning rate : 0.000100 loss : 0.541461 +[08:36:38.115] iteration 10700 [246.15 sec]: learning rate : 0.000100 loss : 0.888151 +[08:38:04.413] iteration 10800 [332.44 sec]: learning rate : 0.000100 loss : 0.462070 +[08:39:30.743] iteration 10900 [418.77 sec]: learning rate : 0.000100 loss : 0.852710 +[08:40:57.021] iteration 11000 [505.05 sec]: learning rate : 0.000100 loss : 0.443757 +[08:42:23.244] iteration 11100 [591.27 sec]: learning rate : 0.000100 loss : 0.695714 +[08:43:49.532] iteration 11200 [677.56 sec]: learning rate : 0.000100 loss : 0.610102 +[08:45:15.793] iteration 11300 [763.82 sec]: learning rate : 0.000100 loss : 0.682788 +[08:46:42.005] iteration 11400 [850.04 sec]: learning rate : 0.000100 loss : 0.751402 +[08:48:08.232] iteration 11500 [936.26 sec]: learning rate : 0.000100 loss : 0.401495 +[08:49:34.420] iteration 11600 [1022.45 sec]: learning rate : 0.000100 loss : 1.018123 +[08:51:00.682] iteration 11700 [1108.71 sec]: learning rate : 0.000100 loss : 0.826384 +[08:52:26.977] iteration 11800 [1195.01 sec]: learning rate : 0.000100 loss : 0.672444 +[08:53:53.251] iteration 11900 [1281.28 sec]: learning rate : 0.000100 loss : 0.473160 +[08:55:19.580] iteration 12000 [1367.61 sec]: learning rate : 0.000100 loss : 0.673737 +[08:56:45.820] iteration 12100 [1453.85 sec]: learning rate : 0.000100 loss : 0.378131 +[08:58:12.056] iteration 12200 [1540.09 sec]: learning rate : 0.000100 loss : 0.602658 +[08:59:38.343] iteration 12300 [1626.37 sec]: learning rate : 0.000100 loss : 0.975415 +[09:01:04.540] iteration 12400 [1712.57 sec]: learning rate : 0.000100 loss : 0.812931 +[09:02:29.026] Epoch 5 Evaluation: +[09:03:20.975] average MSE: 0.060833174735307693 average PSNR: 27.612369462937693 average SSIM: 0.5815955147805237 +[09:03:22.966] iteration 12500 [1.93 sec]: learning rate : 0.000100 loss : 1.035919 +[09:04:49.159] iteration 12600 [88.12 sec]: learning rate : 0.000100 loss : 0.546766 +[09:06:15.465] iteration 12700 [174.43 sec]: learning rate : 0.000100 loss : 0.555355 +[09:07:41.718] iteration 12800 [260.68 sec]: learning rate : 0.000100 loss : 0.589223 +[09:09:07.927] iteration 12900 [346.89 sec]: learning rate : 0.000100 loss : 0.718296 +[09:10:34.217] iteration 13000 [433.18 sec]: learning rate : 0.000100 loss : 0.576205 +[09:12:00.481] iteration 13100 [519.44 sec]: learning rate : 0.000100 loss : 0.534499 +[09:13:26.694] iteration 13200 
[605.66 sec]: learning rate : 0.000100 loss : 0.555998 +[09:14:53.018] iteration 13300 [691.98 sec]: learning rate : 0.000100 loss : 0.626309 +[09:16:19.341] iteration 13400 [778.30 sec]: learning rate : 0.000100 loss : 0.370466 +[09:17:45.599] iteration 13500 [864.56 sec]: learning rate : 0.000100 loss : 0.543036 +[09:19:11.931] iteration 13600 [950.89 sec]: learning rate : 0.000100 loss : 0.604181 +[09:20:38.254] iteration 13700 [1037.22 sec]: learning rate : 0.000100 loss : 0.579700 +[09:22:04.520] iteration 13800 [1123.48 sec]: learning rate : 0.000100 loss : 0.677822 +[09:23:30.820] iteration 13900 [1209.78 sec]: learning rate : 0.000100 loss : 0.833256 +[09:24:57.077] iteration 14000 [1296.04 sec]: learning rate : 0.000100 loss : 0.772199 +[09:26:23.300] iteration 14100 [1382.26 sec]: learning rate : 0.000100 loss : 0.802626 +[09:27:49.616] iteration 14200 [1468.58 sec]: learning rate : 0.000100 loss : 1.034420 +[09:29:15.872] iteration 14300 [1554.84 sec]: learning rate : 0.000100 loss : 0.476870 +[09:30:42.200] iteration 14400 [1641.16 sec]: learning rate : 0.000100 loss : 0.645662 +[09:32:08.504] iteration 14500 [1727.47 sec]: learning rate : 0.000100 loss : 0.550450 +[09:33:18.349] Epoch 6 Evaluation: +[09:34:09.834] average MSE: 0.05932967737317085 average PSNR: 27.7462835765415 average SSIM: 0.5850473755463758 +[09:34:26.439] iteration 14600 [16.54 sec]: learning rate : 0.000100 loss : 0.506978 +[09:35:52.778] iteration 14700 [102.88 sec]: learning rate : 0.000100 loss : 0.481043 +[09:37:19.066] iteration 14800 [189.17 sec]: learning rate : 0.000100 loss : 0.592735 +[09:38:45.302] iteration 14900 [275.41 sec]: learning rate : 0.000100 loss : 0.582268 +[09:40:11.602] iteration 15000 [361.71 sec]: learning rate : 0.000100 loss : 0.708353 +[09:41:37.913] iteration 15100 [448.02 sec]: learning rate : 0.000100 loss : 0.619494 +[09:43:04.174] iteration 15200 [534.28 sec]: learning rate : 0.000100 loss : 0.852084 +[09:44:30.445] iteration 15300 [620.55 sec]: learning rate : 0.000100 loss : 0.874792 +[09:45:56.706] iteration 15400 [706.81 sec]: learning rate : 0.000100 loss : 0.801275 +[09:47:22.944] iteration 15500 [793.05 sec]: learning rate : 0.000100 loss : 0.905272 +[09:48:49.239] iteration 15600 [879.34 sec]: learning rate : 0.000100 loss : 0.882472 +[09:50:15.489] iteration 15700 [965.59 sec]: learning rate : 0.000100 loss : 0.400995 +[09:51:41.810] iteration 15800 [1051.91 sec]: learning rate : 0.000100 loss : 0.855925 +[09:53:08.104] iteration 15900 [1138.21 sec]: learning rate : 0.000100 loss : 0.559926 +[09:54:34.351] iteration 16000 [1224.45 sec]: learning rate : 0.000100 loss : 0.729166 +[09:56:00.653] iteration 16100 [1310.76 sec]: learning rate : 0.000100 loss : 0.380401 +[09:57:26.950] iteration 16200 [1397.05 sec]: learning rate : 0.000100 loss : 0.425018 +[09:58:53.201] iteration 16300 [1483.31 sec]: learning rate : 0.000100 loss : 0.572393 +[10:00:19.494] iteration 16400 [1569.60 sec]: learning rate : 0.000100 loss : 0.538106 +[10:01:45.786] iteration 16500 [1655.89 sec]: learning rate : 0.000100 loss : 0.349107 +[10:03:12.016] iteration 16600 [1742.12 sec]: learning rate : 0.000100 loss : 0.823004 +[10:04:07.250] Epoch 7 Evaluation: +[10:04:56.893] average MSE: 0.05865757167339325 average PSNR: 27.80455300215071 average SSIM: 0.586866012001192 +[10:05:28.133] iteration 16700 [31.18 sec]: learning rate : 0.000100 loss : 0.769167 +[10:06:54.400] iteration 16800 [117.51 sec]: learning rate : 0.000100 loss : 0.473956 +[10:08:20.572] iteration 16900 [203.62 sec]: 
learning rate : 0.000100 loss : 0.497283 +[10:09:46.811] iteration 17000 [289.86 sec]: learning rate : 0.000100 loss : 0.676878 +[10:11:13.012] iteration 17100 [376.06 sec]: learning rate : 0.000100 loss : 0.561200 +[10:12:39.230] iteration 17200 [462.27 sec]: learning rate : 0.000100 loss : 0.317980 +[10:14:05.505] iteration 17300 [548.55 sec]: learning rate : 0.000100 loss : 0.688612 +[10:15:31.756] iteration 17400 [634.80 sec]: learning rate : 0.000100 loss : 0.521733 +[10:16:58.026] iteration 17500 [721.07 sec]: learning rate : 0.000100 loss : 0.644899 +[10:18:24.329] iteration 17600 [807.37 sec]: learning rate : 0.000100 loss : 0.478313 +[10:19:50.569] iteration 17700 [893.61 sec]: learning rate : 0.000100 loss : 0.734991 +[10:21:16.881] iteration 17800 [979.93 sec]: learning rate : 0.000100 loss : 0.550922 +[10:22:43.211] iteration 17900 [1066.26 sec]: learning rate : 0.000100 loss : 0.714702 +[10:24:09.483] iteration 18000 [1152.53 sec]: learning rate : 0.000100 loss : 0.931386 +[10:25:35.832] iteration 18100 [1238.88 sec]: learning rate : 0.000100 loss : 0.553852 +[10:27:02.059] iteration 18200 [1325.10 sec]: learning rate : 0.000100 loss : 0.611231 +[10:28:28.283] iteration 18300 [1411.33 sec]: learning rate : 0.000100 loss : 0.494235 +[10:29:54.539] iteration 18400 [1497.58 sec]: learning rate : 0.000100 loss : 0.444141 +[10:31:20.738] iteration 18500 [1583.78 sec]: learning rate : 0.000100 loss : 0.491969 +[10:32:47.018] iteration 18600 [1670.06 sec]: learning rate : 0.000100 loss : 0.492299 +[10:34:13.223] iteration 18700 [1756.27 sec]: learning rate : 0.000100 loss : 0.377273 +[10:34:53.771] Epoch 8 Evaluation: +[10:35:43.699] average MSE: 0.05896114185452461 average PSNR: 27.797308839766668 average SSIM: 0.5894029730864415 +[10:36:29.630] iteration 18800 [45.87 sec]: learning rate : 0.000100 loss : 1.054159 +[10:37:55.887] iteration 18900 [132.13 sec]: learning rate : 0.000100 loss : 0.481934 +[10:39:22.097] iteration 19000 [218.34 sec]: learning rate : 0.000100 loss : 0.787408 +[10:40:48.418] iteration 19100 [304.66 sec]: learning rate : 0.000100 loss : 0.786007 +[10:42:14.761] iteration 19200 [391.00 sec]: learning rate : 0.000100 loss : 0.614838 +[10:43:41.012] iteration 19300 [477.25 sec]: learning rate : 0.000100 loss : 0.761721 +[10:45:07.318] iteration 19400 [563.56 sec]: learning rate : 0.000100 loss : 1.638125 +[10:46:33.502] iteration 19500 [649.74 sec]: learning rate : 0.000100 loss : 0.578186 +[10:47:59.741] iteration 19600 [735.98 sec]: learning rate : 0.000100 loss : 0.530939 +[10:49:25.967] iteration 19700 [822.21 sec]: learning rate : 0.000100 loss : 0.481359 +[10:50:52.128] iteration 19800 [908.37 sec]: learning rate : 0.000100 loss : 0.423448 +[10:52:18.363] iteration 19900 [994.60 sec]: learning rate : 0.000100 loss : 0.537403 +[10:53:44.596] iteration 20000 [1080.83 sec]: learning rate : 0.000025 loss : 0.454308 +[10:53:44.752] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_20000.pth +[10:55:10.982] iteration 20100 [1167.22 sec]: learning rate : 0.000050 loss : 0.643034 +[10:56:37.259] iteration 20200 [1253.50 sec]: learning rate : 0.000050 loss : 0.821519 +[10:58:03.560] iteration 20300 [1339.80 sec]: learning rate : 0.000050 loss : 0.500219 +[10:59:29.774] iteration 20400 [1426.01 sec]: learning rate : 0.000050 loss : 0.745510 +[11:00:56.051] iteration 20500 [1512.29 sec]: learning rate : 0.000050 loss : 0.658442 +[11:02:22.320] iteration 20600 [1598.56 sec]: learning rate : 0.000050 loss : 0.472217 +[11:03:48.544] iteration 20700 [1684.78 
sec]: learning rate : 0.000050 loss : 0.544569 +[11:05:14.818] iteration 20800 [1771.06 sec]: learning rate : 0.000050 loss : 0.670821 +[11:05:40.656] Epoch 9 Evaluation: +[11:06:29.947] average MSE: 0.057913001626729965 average PSNR: 27.905017482708857 average SSIM: 0.5924856542098208 +[11:07:30.679] iteration 20900 [60.67 sec]: learning rate : 0.000050 loss : 0.577467 +[11:08:56.914] iteration 21000 [146.91 sec]: learning rate : 0.000050 loss : 0.833598 +[11:10:23.227] iteration 21100 [233.22 sec]: learning rate : 0.000050 loss : 0.385963 +[11:11:49.524] iteration 21200 [319.51 sec]: learning rate : 0.000050 loss : 0.639634 +[11:13:15.764] iteration 21300 [405.75 sec]: learning rate : 0.000050 loss : 0.844080 +[11:14:42.128] iteration 21400 [492.12 sec]: learning rate : 0.000050 loss : 0.495613 +[11:16:08.473] iteration 21500 [578.46 sec]: learning rate : 0.000050 loss : 0.581671 +[11:17:34.725] iteration 21600 [664.71 sec]: learning rate : 0.000050 loss : 0.672084 +[11:19:01.051] iteration 21700 [751.04 sec]: learning rate : 0.000050 loss : 0.546827 +[11:20:27.308] iteration 21800 [837.30 sec]: learning rate : 0.000050 loss : 0.451710 +[11:21:53.598] iteration 21900 [923.59 sec]: learning rate : 0.000050 loss : 0.614563 +[11:23:19.950] iteration 22000 [1009.94 sec]: learning rate : 0.000050 loss : 0.701274 +[11:24:46.234] iteration 22100 [1096.22 sec]: learning rate : 0.000050 loss : 1.060020 +[11:26:12.568] iteration 22200 [1182.56 sec]: learning rate : 0.000050 loss : 0.647921 +[11:27:38.898] iteration 22300 [1268.89 sec]: learning rate : 0.000050 loss : 0.815540 +[11:29:05.195] iteration 22400 [1355.19 sec]: learning rate : 0.000050 loss : 0.897583 +[11:30:31.476] iteration 22500 [1441.47 sec]: learning rate : 0.000050 loss : 0.835756 +[11:31:57.795] iteration 22600 [1527.79 sec]: learning rate : 0.000050 loss : 0.396342 +[11:33:24.067] iteration 22700 [1614.06 sec]: learning rate : 0.000050 loss : 0.558727 +[11:34:50.343] iteration 22800 [1700.33 sec]: learning rate : 0.000050 loss : 1.043089 +[11:36:16.567] iteration 22900 [1786.56 sec]: learning rate : 0.000050 loss : 0.432188 +[11:36:27.760] Epoch 10 Evaluation: +[11:37:18.791] average MSE: 0.05786362290382385 average PSNR: 27.907846426904836 average SSIM: 0.5934583142822976 +[11:38:34.059] iteration 23000 [75.21 sec]: learning rate : 0.000050 loss : 0.573912 +[11:40:00.465] iteration 23100 [161.61 sec]: learning rate : 0.000050 loss : 0.735087 +[11:41:26.735] iteration 23200 [247.88 sec]: learning rate : 0.000050 loss : 0.684241 +[11:42:53.071] iteration 23300 [334.22 sec]: learning rate : 0.000050 loss : 0.612501 +[11:44:19.441] iteration 23400 [420.59 sec]: learning rate : 0.000050 loss : 0.452520 +[11:45:45.733] iteration 23500 [506.88 sec]: learning rate : 0.000050 loss : 0.657426 +[11:47:12.050] iteration 23600 [593.20 sec]: learning rate : 0.000050 loss : 0.530132 +[11:48:38.373] iteration 23700 [679.52 sec]: learning rate : 0.000050 loss : 0.637388 +[11:50:04.657] iteration 23800 [765.80 sec]: learning rate : 0.000050 loss : 0.629401 +[11:51:30.984] iteration 23900 [852.13 sec]: learning rate : 0.000050 loss : 0.515311 +[11:52:57.315] iteration 24000 [938.46 sec]: learning rate : 0.000050 loss : 0.785522 +[11:54:23.550] iteration 24100 [1024.70 sec]: learning rate : 0.000050 loss : 0.738359 +[11:55:49.881] iteration 24200 [1111.03 sec]: learning rate : 0.000050 loss : 0.823418 +[11:57:16.204] iteration 24300 [1197.35 sec]: learning rate : 0.000050 loss : 0.696995 +[11:58:42.555] iteration 24400 [1283.70 sec]: learning 
rate : 0.000050 loss : 0.448009 +[12:00:08.898] iteration 24500 [1370.05 sec]: learning rate : 0.000050 loss : 0.632286 +[12:01:35.181] iteration 24600 [1456.33 sec]: learning rate : 0.000050 loss : 0.491077 +[12:03:01.524] iteration 24700 [1542.67 sec]: learning rate : 0.000050 loss : 1.199500 +[12:04:27.868] iteration 24800 [1629.01 sec]: learning rate : 0.000050 loss : 0.862487 +[12:05:54.152] iteration 24900 [1715.30 sec]: learning rate : 0.000050 loss : 0.585472 +[12:07:17.022] Epoch 11 Evaluation: +[12:08:07.687] average MSE: 0.05820254981517792 average PSNR: 27.9057031594872 average SSIM: 0.5944777760384121 +[12:08:11.386] iteration 25000 [3.64 sec]: learning rate : 0.000050 loss : 0.653967 +[12:09:37.738] iteration 25100 [89.99 sec]: learning rate : 0.000050 loss : 0.686378 +[12:11:03.954] iteration 25200 [176.21 sec]: learning rate : 0.000050 loss : 0.609439 +[12:12:30.228] iteration 25300 [262.48 sec]: learning rate : 0.000050 loss : 0.578438 +[12:13:56.556] iteration 25400 [348.81 sec]: learning rate : 0.000050 loss : 0.518495 +[12:15:22.799] iteration 25500 [435.05 sec]: learning rate : 0.000050 loss : 0.417631 +[12:16:49.120] iteration 25600 [521.37 sec]: learning rate : 0.000050 loss : 0.498927 +[12:18:15.422] iteration 25700 [607.67 sec]: learning rate : 0.000050 loss : 0.681544 +[12:19:41.762] iteration 25800 [694.01 sec]: learning rate : 0.000050 loss : 0.468233 +[12:21:08.047] iteration 25900 [780.30 sec]: learning rate : 0.000050 loss : 0.691197 +[12:22:34.282] iteration 26000 [866.53 sec]: learning rate : 0.000050 loss : 0.591168 +[12:24:00.652] iteration 26100 [952.90 sec]: learning rate : 0.000050 loss : 0.574994 +[12:25:26.970] iteration 26200 [1039.22 sec]: learning rate : 0.000050 loss : 0.470356 +[12:26:53.212] iteration 26300 [1125.46 sec]: learning rate : 0.000050 loss : 0.900203 +[12:28:19.505] iteration 26400 [1211.76 sec]: learning rate : 0.000050 loss : 0.585052 +[12:29:45.783] iteration 26500 [1298.03 sec]: learning rate : 0.000050 loss : 0.697643 +[12:31:12.146] iteration 26600 [1384.40 sec]: learning rate : 0.000050 loss : 1.049657 +[12:32:38.428] iteration 26700 [1470.68 sec]: learning rate : 0.000050 loss : 1.098018 +[12:34:04.676] iteration 26800 [1556.93 sec]: learning rate : 0.000050 loss : 0.633174 +[12:35:31.042] iteration 26900 [1643.29 sec]: learning rate : 0.000050 loss : 0.530536 +[12:36:57.346] iteration 27000 [1729.60 sec]: learning rate : 0.000050 loss : 0.351493 +[12:38:05.545] Epoch 12 Evaluation: +[12:38:54.778] average MSE: 0.05871913954615593 average PSNR: 27.869475160219267 average SSIM: 0.5947281083802304 +[12:39:13.118] iteration 27100 [18.28 sec]: learning rate : 0.000050 loss : 0.640630 +[12:40:39.474] iteration 27200 [104.63 sec]: learning rate : 0.000050 loss : 0.674785 +[12:42:05.771] iteration 27300 [190.93 sec]: learning rate : 0.000050 loss : 0.927779 +[12:43:32.123] iteration 27400 [277.28 sec]: learning rate : 0.000050 loss : 0.419984 +[12:44:58.490] iteration 27500 [363.65 sec]: learning rate : 0.000050 loss : 0.468889 +[12:46:24.835] iteration 27600 [449.99 sec]: learning rate : 0.000050 loss : 0.632805 +[12:47:51.160] iteration 27700 [536.32 sec]: learning rate : 0.000050 loss : 0.768526 +[12:49:17.469] iteration 27800 [622.69 sec]: learning rate : 0.000050 loss : 0.643640 +[12:50:43.731] iteration 27900 [708.89 sec]: learning rate : 0.000050 loss : 0.567725 +[12:52:10.095] iteration 28000 [795.25 sec]: learning rate : 0.000050 loss : 0.697371 +[12:53:36.409] iteration 28100 [881.57 sec]: learning rate : 0.000050 loss : 
0.496827 +[12:55:02.795] iteration 28200 [967.95 sec]: learning rate : 0.000050 loss : 0.711058 +[12:56:29.185] iteration 28300 [1054.34 sec]: learning rate : 0.000050 loss : 0.603857 +[12:57:55.492] iteration 28400 [1140.65 sec]: learning rate : 0.000050 loss : 0.547136 +[12:59:21.868] iteration 28500 [1227.03 sec]: learning rate : 0.000050 loss : 0.673883 +[13:00:48.229] iteration 28600 [1313.39 sec]: learning rate : 0.000050 loss : 0.692519 +[13:02:14.533] iteration 28700 [1399.69 sec]: learning rate : 0.000050 loss : 0.507544 +[13:03:40.894] iteration 28800 [1486.05 sec]: learning rate : 0.000050 loss : 0.982437 +[13:05:07.234] iteration 28900 [1572.39 sec]: learning rate : 0.000050 loss : 1.054573 +[13:06:33.530] iteration 29000 [1658.69 sec]: learning rate : 0.000050 loss : 0.740069 +[13:07:59.861] iteration 29100 [1745.02 sec]: learning rate : 0.000050 loss : 0.820807 +[13:08:53.325] Epoch 13 Evaluation: +[13:09:42.614] average MSE: 0.05808725953102112 average PSNR: 27.91516479142096 average SSIM: 0.5956752937568905 +[13:10:15.748] iteration 29200 [33.07 sec]: learning rate : 0.000050 loss : 0.371228 +[13:11:42.042] iteration 29300 [119.37 sec]: learning rate : 0.000050 loss : 0.643244 +[13:13:08.372] iteration 29400 [205.70 sec]: learning rate : 0.000050 loss : 0.522669 +[13:14:34.660] iteration 29500 [291.98 sec]: learning rate : 0.000050 loss : 0.522399 +[13:16:01.017] iteration 29600 [378.34 sec]: learning rate : 0.000050 loss : 0.701059 +[13:17:27.389] iteration 29700 [464.71 sec]: learning rate : 0.000050 loss : 1.019208 +[13:18:53.693] iteration 29800 [551.02 sec]: learning rate : 0.000050 loss : 0.559866 +[13:20:20.077] iteration 29900 [637.40 sec]: learning rate : 0.000050 loss : 0.568384 +[13:21:46.447] iteration 30000 [723.77 sec]: learning rate : 0.000050 loss : 0.540130 +[13:23:12.760] iteration 30100 [810.08 sec]: learning rate : 0.000050 loss : 0.693469 +[13:24:39.104] iteration 30200 [896.43 sec]: learning rate : 0.000050 loss : 0.645467 +[13:26:05.431] iteration 30300 [982.75 sec]: learning rate : 0.000050 loss : 0.713284 +[13:27:31.711] iteration 30400 [1069.03 sec]: learning rate : 0.000050 loss : 0.891398 +[13:28:58.081] iteration 30500 [1155.40 sec]: learning rate : 0.000050 loss : 0.561638 +[13:30:24.434] iteration 30600 [1241.76 sec]: learning rate : 0.000050 loss : 0.722563 +[13:31:50.741] iteration 30700 [1328.06 sec]: learning rate : 0.000050 loss : 0.549373 +[13:33:17.092] iteration 30800 [1414.41 sec]: learning rate : 0.000050 loss : 0.904834 +[13:34:43.400] iteration 30900 [1500.72 sec]: learning rate : 0.000050 loss : 0.723159 +[13:36:09.774] iteration 31000 [1587.10 sec]: learning rate : 0.000050 loss : 0.651578 +[13:37:36.128] iteration 31100 [1673.45 sec]: learning rate : 0.000050 loss : 0.472426 +[13:39:02.430] iteration 31200 [1759.75 sec]: learning rate : 0.000050 loss : 0.599370 +[13:39:41.299] Epoch 14 Evaluation: +[13:40:30.874] average MSE: 0.057692527770996094 average PSNR: 27.958999625906177 average SSIM: 0.5952897625212644 +[13:41:18.519] iteration 31300 [47.58 sec]: learning rate : 0.000050 loss : 0.678323 +[13:42:44.877] iteration 31400 [133.94 sec]: learning rate : 0.000050 loss : 0.513713 +[13:44:11.141] iteration 31500 [220.20 sec]: learning rate : 0.000050 loss : 0.750180 +[13:45:37.533] iteration 31600 [306.60 sec]: learning rate : 0.000050 loss : 0.852661 +[13:47:03.850] iteration 31700 [392.91 sec]: learning rate : 0.000050 loss : 0.461039 +[13:48:30.141] iteration 31800 [479.20 sec]: learning rate : 0.000050 loss : 0.512907 
+[13:49:56.451] iteration 31900 [565.51 sec]: learning rate : 0.000050 loss : 0.703911 +[13:51:22.676] iteration 32000 [651.74 sec]: learning rate : 0.000050 loss : 0.899406 +[13:52:48.944] iteration 32100 [738.01 sec]: learning rate : 0.000050 loss : 0.590397 +[13:54:15.252] iteration 32200 [824.31 sec]: learning rate : 0.000050 loss : 0.537748 +[13:55:41.503] iteration 32300 [910.57 sec]: learning rate : 0.000050 loss : 0.432986 +[13:57:07.850] iteration 32400 [996.91 sec]: learning rate : 0.000050 loss : 0.549956 +[13:58:34.097] iteration 32500 [1083.16 sec]: learning rate : 0.000050 loss : 0.305137 +[14:00:00.415] iteration 32600 [1169.48 sec]: learning rate : 0.000050 loss : 0.766955 +[14:01:26.791] iteration 32700 [1255.85 sec]: learning rate : 0.000050 loss : 0.527446 +[14:02:53.155] iteration 32800 [1342.22 sec]: learning rate : 0.000050 loss : 0.634685 +[14:04:19.551] iteration 32900 [1428.61 sec]: learning rate : 0.000050 loss : 0.695133 +[14:05:45.931] iteration 33000 [1514.99 sec]: learning rate : 0.000050 loss : 0.681834 +[14:07:12.244] iteration 33100 [1601.31 sec]: learning rate : 0.000050 loss : 0.501289 +[14:08:38.602] iteration 33200 [1687.67 sec]: learning rate : 0.000050 loss : 0.697992 +[14:10:04.913] iteration 33300 [1773.98 sec]: learning rate : 0.000050 loss : 0.572933 +[14:10:29.108] Epoch 15 Evaluation: +[14:11:20.543] average MSE: 0.058326516300439835 average PSNR: 27.90155719973212 average SSIM: 0.596409309111691 +[14:12:22.893] iteration 33400 [62.29 sec]: learning rate : 0.000050 loss : 0.512797 +[14:13:49.223] iteration 33500 [148.62 sec]: learning rate : 0.000050 loss : 1.108583 +[14:15:15.528] iteration 33600 [234.92 sec]: learning rate : 0.000050 loss : 0.285223 +[14:16:41.892] iteration 33700 [321.29 sec]: learning rate : 0.000050 loss : 0.519183 +[14:18:08.231] iteration 33800 [407.63 sec]: learning rate : 0.000050 loss : 0.753731 +[14:19:34.557] iteration 33900 [493.95 sec]: learning rate : 0.000050 loss : 0.477787 +[14:21:00.937] iteration 34000 [580.33 sec]: learning rate : 0.000050 loss : 0.394808 +[14:22:27.245] iteration 34100 [666.64 sec]: learning rate : 0.000050 loss : 0.577421 +[14:23:53.523] iteration 34200 [752.92 sec]: learning rate : 0.000050 loss : 0.415665 +[14:25:19.885] iteration 34300 [839.28 sec]: learning rate : 0.000050 loss : 0.504582 +[14:26:46.223] iteration 34400 [925.62 sec]: learning rate : 0.000050 loss : 0.619838 +[14:28:12.598] iteration 34500 [1011.99 sec]: learning rate : 0.000050 loss : 0.908552 +[14:29:38.934] iteration 34600 [1098.33 sec]: learning rate : 0.000050 loss : 0.859619 +[14:31:05.302] iteration 34700 [1184.70 sec]: learning rate : 0.000050 loss : 0.638848 +[14:32:31.632] iteration 34800 [1271.03 sec]: learning rate : 0.000050 loss : 1.215568 +[14:33:57.955] iteration 34900 [1357.35 sec]: learning rate : 0.000050 loss : 0.729145 +[14:35:24.328] iteration 35000 [1443.72 sec]: learning rate : 0.000050 loss : 0.977760 +[14:36:50.606] iteration 35100 [1530.00 sec]: learning rate : 0.000050 loss : 0.536514 +[14:38:16.890] iteration 35200 [1616.29 sec]: learning rate : 0.000050 loss : 0.674591 +[14:39:43.237] iteration 35300 [1702.63 sec]: learning rate : 0.000050 loss : 0.904846 +[14:41:09.561] iteration 35400 [1788.96 sec]: learning rate : 0.000050 loss : 0.726472 +[14:41:19.043] Epoch 16 Evaluation: +[14:42:10.997] average MSE: 0.05765249952673912 average PSNR: 27.96063197546215 average SSIM: 0.5969540251015694 +[14:43:28.186] iteration 35500 [77.13 sec]: learning rate : 0.000050 loss : 0.502791 +[14:44:54.506] 
iteration 35600 [163.45 sec]: learning rate : 0.000050 loss : 0.881457 +[14:46:20.863] iteration 35700 [249.80 sec]: learning rate : 0.000050 loss : 0.715742 +[14:47:47.271] iteration 35800 [336.21 sec]: learning rate : 0.000050 loss : 0.604992 +[14:49:13.561] iteration 35900 [422.50 sec]: learning rate : 0.000050 loss : 0.481134 +[14:50:39.914] iteration 36000 [508.85 sec]: learning rate : 0.000050 loss : 0.925958 +[14:52:06.239] iteration 36100 [595.18 sec]: learning rate : 0.000050 loss : 0.448116 +[14:53:32.578] iteration 36200 [681.52 sec]: learning rate : 0.000050 loss : 0.636963 +[14:54:58.957] iteration 36300 [767.90 sec]: learning rate : 0.000050 loss : 1.177419 +[14:56:25.276] iteration 36400 [854.22 sec]: learning rate : 0.000050 loss : 0.583580 +[14:57:51.687] iteration 36500 [940.63 sec]: learning rate : 0.000050 loss : 0.515885 +[14:59:18.087] iteration 36600 [1027.03 sec]: learning rate : 0.000050 loss : 0.579513 +[15:00:44.437] iteration 36700 [1113.38 sec]: learning rate : 0.000050 loss : 0.628145 +[15:02:10.831] iteration 36800 [1199.77 sec]: learning rate : 0.000050 loss : 0.681086 +[15:03:37.177] iteration 36900 [1286.12 sec]: learning rate : 0.000050 loss : 0.551118 +[15:05:03.481] iteration 37000 [1372.42 sec]: learning rate : 0.000050 loss : 0.566175 +[15:06:29.861] iteration 37100 [1458.80 sec]: learning rate : 0.000050 loss : 0.642014 +[15:07:56.180] iteration 37200 [1545.12 sec]: learning rate : 0.000050 loss : 0.438644 +[15:09:22.578] iteration 37300 [1631.52 sec]: learning rate : 0.000050 loss : 0.523392 +[15:10:48.998] iteration 37400 [1717.94 sec]: learning rate : 0.000050 loss : 0.531398 +[15:12:10.125] Epoch 17 Evaluation: +[15:13:01.656] average MSE: 0.0579269640147686 average PSNR: 27.939487857888512 average SSIM: 0.5971589408761954 +[15:13:07.096] iteration 37500 [5.38 sec]: learning rate : 0.000050 loss : 0.421281 +[15:14:33.475] iteration 37600 [91.76 sec]: learning rate : 0.000050 loss : 0.520148 +[15:15:59.859] iteration 37700 [178.14 sec]: learning rate : 0.000050 loss : 0.614690 +[15:17:26.179] iteration 37800 [264.46 sec]: learning rate : 0.000050 loss : 1.287497 +[15:18:52.574] iteration 37900 [350.86 sec]: learning rate : 0.000050 loss : 0.715827 +[15:20:18.974] iteration 38000 [437.26 sec]: learning rate : 0.000050 loss : 0.687019 +[15:21:45.265] iteration 38100 [523.55 sec]: learning rate : 0.000050 loss : 0.657023 +[15:23:11.658] iteration 38200 [609.94 sec]: learning rate : 0.000050 loss : 0.618956 +[15:24:37.981] iteration 38300 [696.26 sec]: learning rate : 0.000050 loss : 0.565679 +[15:26:04.330] iteration 38400 [782.61 sec]: learning rate : 0.000050 loss : 0.819697 +[15:27:30.670] iteration 38500 [868.95 sec]: learning rate : 0.000050 loss : 0.599194 +[15:28:56.976] iteration 38600 [955.26 sec]: learning rate : 0.000050 loss : 0.527983 +[15:30:23.364] iteration 38700 [1041.65 sec]: learning rate : 0.000050 loss : 0.549797 +[15:31:49.778] iteration 38800 [1128.06 sec]: learning rate : 0.000050 loss : 0.552740 +[15:33:16.116] iteration 38900 [1214.40 sec]: learning rate : 0.000050 loss : 0.524228 +[15:34:42.497] iteration 39000 [1300.78 sec]: learning rate : 0.000050 loss : 0.734961 +[15:36:08.915] iteration 39100 [1387.20 sec]: learning rate : 0.000050 loss : 0.448988 +[15:37:35.217] iteration 39200 [1473.50 sec]: learning rate : 0.000050 loss : 0.595631 +[15:39:01.559] iteration 39300 [1559.84 sec]: learning rate : 0.000050 loss : 0.866648 +[15:40:27.919] iteration 39400 [1646.20 sec]: learning rate : 0.000050 loss : 0.613500 
+[15:41:54.207] iteration 39500 [1732.49 sec]: learning rate : 0.000050 loss : 0.719005 +[15:43:00.671] Epoch 18 Evaluation: +[15:43:51.219] average MSE: 0.057222187519073486 average PSNR: 27.99965454138081 average SSIM: 0.5970050292364608 +[15:44:11.317] iteration 39600 [20.04 sec]: learning rate : 0.000050 loss : 0.351625 +[15:45:37.692] iteration 39700 [106.41 sec]: learning rate : 0.000050 loss : 0.680003 +[15:47:04.162] iteration 39800 [192.88 sec]: learning rate : 0.000050 loss : 0.561453 +[15:48:30.569] iteration 39900 [279.29 sec]: learning rate : 0.000050 loss : 0.372058 +[15:49:56.884] iteration 40000 [365.60 sec]: learning rate : 0.000013 loss : 0.468863 +[15:49:57.045] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_40000.pth +[15:51:23.410] iteration 40100 [452.13 sec]: learning rate : 0.000025 loss : 0.889163 +[15:52:49.818] iteration 40200 [538.54 sec]: learning rate : 0.000025 loss : 0.700554 +[15:54:16.107] iteration 40300 [624.83 sec]: learning rate : 0.000025 loss : 0.789185 +[15:55:42.456] iteration 40400 [711.17 sec]: learning rate : 0.000025 loss : 0.664141 +[15:57:08.801] iteration 40500 [797.52 sec]: learning rate : 0.000025 loss : 0.358813 +[15:58:35.099] iteration 40600 [883.82 sec]: learning rate : 0.000025 loss : 0.760751 +[16:00:01.478] iteration 40700 [970.20 sec]: learning rate : 0.000025 loss : 0.714375 +[16:01:27.853] iteration 40800 [1056.57 sec]: learning rate : 0.000025 loss : 0.419001 +[16:02:54.141] iteration 40900 [1142.86 sec]: learning rate : 0.000025 loss : 0.865106 +[16:04:20.500] iteration 41000 [1229.22 sec]: learning rate : 0.000025 loss : 0.557978 +[16:05:46.852] iteration 41100 [1315.57 sec]: learning rate : 0.000025 loss : 0.798250 +[16:07:13.128] iteration 41200 [1401.85 sec]: learning rate : 0.000025 loss : 0.519366 +[16:08:39.504] iteration 41300 [1488.22 sec]: learning rate : 0.000025 loss : 0.717059 +[16:10:05.867] iteration 41400 [1574.59 sec]: learning rate : 0.000025 loss : 0.814453 +[16:11:32.193] iteration 41500 [1660.91 sec]: learning rate : 0.000025 loss : 0.956431 +[16:12:58.538] iteration 41600 [1747.26 sec]: learning rate : 0.000025 loss : 0.698550 +[16:13:50.299] Epoch 19 Evaluation: +[16:14:41.927] average MSE: 0.05757247284054756 average PSNR: 28.004004389652312 average SSIM: 0.5981279733393106 +[16:15:16.662] iteration 41700 [34.67 sec]: learning rate : 0.000025 loss : 0.530446 +[16:16:42.956] iteration 41800 [120.97 sec]: learning rate : 0.000025 loss : 0.764648 +[16:18:09.350] iteration 41900 [207.36 sec]: learning rate : 0.000025 loss : 0.412423 +[16:19:35.721] iteration 42000 [293.73 sec]: learning rate : 0.000025 loss : 0.561935 +[16:21:02.047] iteration 42100 [380.06 sec]: learning rate : 0.000025 loss : 0.590327 +[16:22:28.359] iteration 42200 [466.37 sec]: learning rate : 0.000025 loss : 0.837639 +[16:23:54.622] iteration 42300 [552.63 sec]: learning rate : 0.000025 loss : 0.761523 +[16:25:21.015] iteration 42400 [639.02 sec]: learning rate : 0.000025 loss : 0.413751 +[16:26:47.378] iteration 42500 [725.39 sec]: learning rate : 0.000025 loss : 0.339070 +[16:28:13.722] iteration 42600 [811.73 sec]: learning rate : 0.000025 loss : 0.536881 +[16:29:40.073] iteration 42700 [898.08 sec]: learning rate : 0.000025 loss : 0.647579 +[16:31:06.366] iteration 42800 [984.38 sec]: learning rate : 0.000025 loss : 0.797593 +[16:32:32.680] iteration 42900 [1070.69 sec]: learning rate : 0.000025 loss : 0.473208 +[16:33:59.059] iteration 43000 [1157.07 sec]: learning rate : 0.000025 loss : 0.406572 +[16:35:25.405] 
iteration 43100 [1243.41 sec]: learning rate : 0.000025 loss : 0.562954 +[16:36:51.727] iteration 43200 [1329.74 sec]: learning rate : 0.000025 loss : 0.765770 +[16:38:18.025] iteration 43300 [1416.04 sec]: learning rate : 0.000025 loss : 0.622599 +[16:39:44.418] iteration 43400 [1502.43 sec]: learning rate : 0.000025 loss : 0.679205 +[16:41:10.746] iteration 43500 [1588.76 sec]: learning rate : 0.000025 loss : 0.923160 +[16:42:37.075] iteration 43600 [1675.09 sec]: learning rate : 0.000025 loss : 0.337802 +[16:44:03.453] iteration 43700 [1761.46 sec]: learning rate : 0.000025 loss : 0.761516 +[16:44:40.548] Epoch 20 Evaluation: +[16:45:30.012] average MSE: 0.05659179762005806 average PSNR: 28.069950289347403 average SSIM: 0.5977004541529017 +[16:46:19.537] iteration 43800 [49.46 sec]: learning rate : 0.000025 loss : 1.234639 +[16:47:45.871] iteration 43900 [135.80 sec]: learning rate : 0.000025 loss : 0.408570 +[16:49:12.209] iteration 44000 [222.13 sec]: learning rate : 0.000025 loss : 0.656590 +[16:50:38.542] iteration 44100 [308.47 sec]: learning rate : 0.000025 loss : 0.921847 +[16:52:04.909] iteration 44200 [394.83 sec]: learning rate : 0.000025 loss : 0.465411 +[16:53:31.209] iteration 44300 [481.13 sec]: learning rate : 0.000025 loss : 0.868429 +[16:54:57.500] iteration 44400 [567.43 sec]: learning rate : 0.000025 loss : 0.469303 +[16:56:23.829] iteration 44500 [653.75 sec]: learning rate : 0.000025 loss : 0.610244 +[16:57:50.111] iteration 44600 [740.03 sec]: learning rate : 0.000025 loss : 0.521616 +[16:59:16.468] iteration 44700 [826.39 sec]: learning rate : 0.000025 loss : 0.678121 +[17:00:42.850] iteration 44800 [912.77 sec]: learning rate : 0.000025 loss : 0.706359 +[17:02:09.177] iteration 44900 [999.10 sec]: learning rate : 0.000025 loss : 0.915983 +[17:03:35.493] iteration 45000 [1085.42 sec]: learning rate : 0.000025 loss : 0.943585 +[17:05:01.829] iteration 45100 [1171.77 sec]: learning rate : 0.000025 loss : 0.577550 +[17:06:28.194] iteration 45200 [1258.12 sec]: learning rate : 0.000025 loss : 0.461367 +[17:07:54.568] iteration 45300 [1344.50 sec]: learning rate : 0.000025 loss : 0.404565 +[17:09:20.902] iteration 45400 [1430.83 sec]: learning rate : 0.000025 loss : 0.600068 +[17:10:47.281] iteration 45500 [1517.21 sec]: learning rate : 0.000025 loss : 0.401938 +[17:12:13.612] iteration 45600 [1603.54 sec]: learning rate : 0.000025 loss : 0.778972 +[17:13:39.900] iteration 45700 [1689.82 sec]: learning rate : 0.000025 loss : 0.481931 +[17:15:06.269] iteration 45800 [1776.19 sec]: learning rate : 0.000025 loss : 0.523598 +[17:15:28.686] Epoch 21 Evaluation: +[17:16:18.930] average MSE: 0.05687125399708748 average PSNR: 28.06132173730316 average SSIM: 0.5992032564311214 +[17:17:22.994] iteration 45900 [64.00 sec]: learning rate : 0.000025 loss : 0.829956 +[17:18:49.393] iteration 46000 [150.40 sec]: learning rate : 0.000025 loss : 0.670588 +[17:20:15.702] iteration 46100 [236.71 sec]: learning rate : 0.000025 loss : 0.423019 +[17:21:42.003] iteration 46200 [323.01 sec]: learning rate : 0.000025 loss : 0.769088 +[17:23:08.362] iteration 46300 [409.37 sec]: learning rate : 0.000025 loss : 0.896693 +[17:24:34.703] iteration 46400 [495.71 sec]: learning rate : 0.000025 loss : 0.698502 +[17:26:01.047] iteration 46500 [582.05 sec]: learning rate : 0.000025 loss : 0.326345 +[17:27:27.424] iteration 46600 [668.43 sec]: learning rate : 0.000025 loss : 0.438863 +[17:28:53.747] iteration 46700 [754.75 sec]: learning rate : 0.000025 loss : 0.542659 +[17:30:20.059] iteration 46800 
[841.07 sec]: learning rate : 0.000025 loss : 0.537644 +[17:31:46.407] iteration 46900 [927.41 sec]: learning rate : 0.000025 loss : 0.382106 +[17:33:12.683] iteration 47000 [1013.69 sec]: learning rate : 0.000025 loss : 0.663902 +[17:34:39.058] iteration 47100 [1100.06 sec]: learning rate : 0.000025 loss : 0.495871 +[17:36:05.378] iteration 47200 [1186.39 sec]: learning rate : 0.000025 loss : 0.499109 +[17:37:31.777] iteration 47300 [1272.78 sec]: learning rate : 0.000025 loss : 0.904925 +[17:38:58.179] iteration 47400 [1359.19 sec]: learning rate : 0.000025 loss : 0.482162 +[17:40:24.513] iteration 47500 [1445.52 sec]: learning rate : 0.000025 loss : 0.505313 +[17:41:50.900] iteration 47600 [1531.91 sec]: learning rate : 0.000025 loss : 0.332335 +[17:43:17.259] iteration 47700 [1618.27 sec]: learning rate : 0.000025 loss : 0.653883 +[17:44:43.515] iteration 47800 [1704.52 sec]: learning rate : 0.000025 loss : 0.806563 +[17:46:09.873] iteration 47900 [1790.88 sec]: learning rate : 0.000025 loss : 0.270932 +[17:46:17.611] Epoch 22 Evaluation: +[17:47:06.831] average MSE: 0.05681406334042549 average PSNR: 28.067959372478928 average SSIM: 0.5989247813601238 +[17:48:25.650] iteration 48000 [78.76 sec]: learning rate : 0.000025 loss : 0.599660 +[17:49:51.934] iteration 48100 [165.04 sec]: learning rate : 0.000025 loss : 0.426478 +[17:51:18.281] iteration 48200 [251.39 sec]: learning rate : 0.000025 loss : 0.512170 +[17:52:44.607] iteration 48300 [337.71 sec]: learning rate : 0.000025 loss : 0.601772 +[17:54:10.994] iteration 48400 [424.10 sec]: learning rate : 0.000025 loss : 0.491191 +[17:55:37.382] iteration 48500 [510.49 sec]: learning rate : 0.000025 loss : 0.995776 +[17:57:03.708] iteration 48600 [596.81 sec]: learning rate : 0.000025 loss : 0.362773 +[17:58:30.094] iteration 48700 [683.20 sec]: learning rate : 0.000025 loss : 0.857118 +[17:59:56.426] iteration 48800 [769.53 sec]: learning rate : 0.000025 loss : 0.816315 +[18:01:22.699] iteration 48900 [855.81 sec]: learning rate : 0.000025 loss : 0.386976 +[18:02:49.094] iteration 49000 [942.20 sec]: learning rate : 0.000025 loss : 0.534402 +[18:04:15.362] iteration 49100 [1028.47 sec]: learning rate : 0.000025 loss : 0.410800 +[18:05:41.727] iteration 49200 [1114.83 sec]: learning rate : 0.000025 loss : 0.611507 +[18:07:08.110] iteration 49300 [1201.22 sec]: learning rate : 0.000025 loss : 0.610853 +[18:08:34.447] iteration 49400 [1287.55 sec]: learning rate : 0.000025 loss : 0.438207 +[18:10:00.830] iteration 49500 [1373.94 sec]: learning rate : 0.000025 loss : 0.540036 +[18:11:27.100] iteration 49600 [1460.21 sec]: learning rate : 0.000025 loss : 0.863025 +[18:12:53.400] iteration 49700 [1546.51 sec]: learning rate : 0.000025 loss : 0.821945 +[18:14:19.786] iteration 49800 [1632.89 sec]: learning rate : 0.000025 loss : 0.577021 +[18:15:46.112] iteration 49900 [1719.22 sec]: learning rate : 0.000025 loss : 0.674080 +[18:17:05.550] Epoch 23 Evaluation: +[18:17:54.808] average MSE: 0.05716777965426445 average PSNR: 28.026371082299974 average SSIM: 0.5983756903609689 +[18:18:01.952] iteration 50000 [7.08 sec]: learning rate : 0.000025 loss : 0.374507 +[18:19:28.382] iteration 50100 [93.51 sec]: learning rate : 0.000025 loss : 0.400630 +[18:20:54.700] iteration 50200 [179.83 sec]: learning rate : 0.000025 loss : 0.552593 +[18:22:21.015] iteration 50300 [266.14 sec]: learning rate : 0.000025 loss : 1.079659 +[18:23:47.300] iteration 50400 [352.43 sec]: learning rate : 0.000025 loss : 0.945370 +[18:25:13.564] iteration 50500 [438.69 sec]: 
learning rate : 0.000025 loss : 0.849402 +[18:26:39.926] iteration 50600 [525.05 sec]: learning rate : 0.000025 loss : 0.620571 +[18:28:06.269] iteration 50700 [611.40 sec]: learning rate : 0.000025 loss : 0.477008 +[18:29:32.541] iteration 50800 [697.67 sec]: learning rate : 0.000025 loss : 0.565940 +[18:30:58.909] iteration 50900 [784.04 sec]: learning rate : 0.000025 loss : 0.661832 +[18:32:25.321] iteration 51000 [870.45 sec]: learning rate : 0.000025 loss : 0.886061 +[18:33:51.630] iteration 51100 [956.76 sec]: learning rate : 0.000025 loss : 0.433378 +[18:35:18.026] iteration 51200 [1043.16 sec]: learning rate : 0.000025 loss : 0.674052 +[18:36:44.391] iteration 51300 [1129.52 sec]: learning rate : 0.000025 loss : 0.324643 +[18:38:10.748] iteration 51400 [1215.88 sec]: learning rate : 0.000025 loss : 0.473934 +[18:39:37.148] iteration 51500 [1302.28 sec]: learning rate : 0.000025 loss : 0.734336 +[18:41:03.576] iteration 51600 [1388.70 sec]: learning rate : 0.000025 loss : 0.575684 +[18:42:29.924] iteration 51700 [1475.05 sec]: learning rate : 0.000025 loss : 0.437653 +[18:43:56.245] iteration 51800 [1561.37 sec]: learning rate : 0.000025 loss : 0.807283 +[18:45:22.601] iteration 51900 [1647.73 sec]: learning rate : 0.000025 loss : 0.835059 +[18:46:48.893] iteration 52000 [1734.02 sec]: learning rate : 0.000025 loss : 0.787902 +[18:47:53.674] Epoch 24 Evaluation: +[18:48:44.356] average MSE: 0.0568385124206543 average PSNR: 28.067761616210046 average SSIM: 0.5996529300479186 +[18:49:06.158] iteration 52100 [21.74 sec]: learning rate : 0.000025 loss : 0.610308 +[18:50:32.540] iteration 52200 [108.12 sec]: learning rate : 0.000025 loss : 0.497311 +[18:51:58.850] iteration 52300 [194.43 sec]: learning rate : 0.000025 loss : 0.399685 +[18:53:25.257] iteration 52400 [280.84 sec]: learning rate : 0.000025 loss : 0.644999 +[18:54:51.647] iteration 52500 [367.23 sec]: learning rate : 0.000025 loss : 0.588668 +[18:56:17.949] iteration 52600 [453.53 sec]: learning rate : 0.000025 loss : 0.709082 +[18:57:44.329] iteration 52700 [539.91 sec]: learning rate : 0.000025 loss : 0.572137 +[18:59:10.682] iteration 52800 [626.26 sec]: learning rate : 0.000025 loss : 0.624935 +[19:00:37.079] iteration 52900 [712.66 sec]: learning rate : 0.000025 loss : 0.534611 +[19:02:03.475] iteration 53000 [799.06 sec]: learning rate : 0.000025 loss : 0.929789 +[19:03:29.810] iteration 53100 [885.39 sec]: learning rate : 0.000025 loss : 0.541529 +[19:04:56.186] iteration 53200 [971.77 sec]: learning rate : 0.000025 loss : 0.816715 +[19:06:22.552] iteration 53300 [1058.13 sec]: learning rate : 0.000025 loss : 0.859607 +[19:07:48.954] iteration 53400 [1144.54 sec]: learning rate : 0.000025 loss : 0.451213 +[19:09:15.328] iteration 53500 [1230.91 sec]: learning rate : 0.000025 loss : 0.990314 +[19:10:41.693] iteration 53600 [1317.27 sec]: learning rate : 0.000025 loss : 0.403030 +[19:12:08.083] iteration 53700 [1403.66 sec]: learning rate : 0.000025 loss : 0.730154 +[19:13:34.445] iteration 53800 [1490.03 sec]: learning rate : 0.000025 loss : 0.551138 +[19:15:00.846] iteration 53900 [1576.43 sec]: learning rate : 0.000025 loss : 0.385206 +[19:16:27.250] iteration 54000 [1662.83 sec]: learning rate : 0.000025 loss : 0.358180 +[19:17:53.587] iteration 54100 [1749.17 sec]: learning rate : 0.000025 loss : 0.305134 +[19:18:43.700] Epoch 25 Evaluation: +[19:19:34.701] average MSE: 0.0564347505569458 average PSNR: 28.092498141062062 average SSIM: 0.5994645529183728 +[19:20:11.170] iteration 54200 [36.41 sec]: learning rate : 
0.000025 loss : 0.571173 +[19:21:37.568] iteration 54300 [122.80 sec]: learning rate : 0.000025 loss : 0.450204 +[19:23:03.885] iteration 54400 [209.12 sec]: learning rate : 0.000025 loss : 0.670893 +[19:24:30.236] iteration 54500 [295.47 sec]: learning rate : 0.000025 loss : 0.754853 +[19:25:56.588] iteration 54600 [381.82 sec]: learning rate : 0.000025 loss : 0.626832 +[19:27:22.870] iteration 54700 [468.11 sec]: learning rate : 0.000025 loss : 0.436201 +[19:28:49.209] iteration 54800 [554.45 sec]: learning rate : 0.000025 loss : 0.822387 +[19:30:15.472] iteration 54900 [640.71 sec]: learning rate : 0.000025 loss : 0.463992 +[19:31:41.857] iteration 55000 [727.10 sec]: learning rate : 0.000025 loss : 0.683489 +[19:33:08.172] iteration 55100 [813.41 sec]: learning rate : 0.000025 loss : 0.637593 +[19:34:34.433] iteration 55200 [899.67 sec]: learning rate : 0.000025 loss : 0.292266 +[19:36:00.765] iteration 55300 [986.00 sec]: learning rate : 0.000025 loss : 0.393203 +[19:37:27.073] iteration 55400 [1072.31 sec]: learning rate : 0.000025 loss : 1.017181 +[19:38:53.395] iteration 55500 [1158.63 sec]: learning rate : 0.000025 loss : 0.788302 +[19:40:19.702] iteration 55600 [1244.94 sec]: learning rate : 0.000025 loss : 0.799066 +[19:41:46.004] iteration 55700 [1331.24 sec]: learning rate : 0.000025 loss : 0.580418 +[19:43:12.381] iteration 55800 [1417.62 sec]: learning rate : 0.000025 loss : 0.719859 +[19:44:38.714] iteration 55900 [1503.95 sec]: learning rate : 0.000025 loss : 0.668439 +[19:46:04.982] iteration 56000 [1590.22 sec]: learning rate : 0.000025 loss : 0.804553 +[19:47:31.315] iteration 56100 [1676.55 sec]: learning rate : 0.000025 loss : 0.593095 +[19:48:57.662] iteration 56200 [1762.90 sec]: learning rate : 0.000025 loss : 0.669780 +[19:49:33.009] Epoch 26 Evaluation: +[19:50:22.618] average MSE: 0.05635973811149597 average PSNR: 28.10351637291031 average SSIM: 0.6007896011132333 +[19:51:13.788] iteration 56300 [51.11 sec]: learning rate : 0.000025 loss : 0.672126 +[19:52:40.197] iteration 56400 [137.52 sec]: learning rate : 0.000025 loss : 0.512896 +[19:54:06.544] iteration 56500 [223.86 sec]: learning rate : 0.000025 loss : 0.723107 +[19:55:32.891] iteration 56600 [310.21 sec]: learning rate : 0.000025 loss : 0.701434 +[19:56:59.259] iteration 56700 [396.58 sec]: learning rate : 0.000025 loss : 0.392383 +[19:58:25.649] iteration 56800 [482.97 sec]: learning rate : 0.000025 loss : 0.623445 +[19:59:52.002] iteration 56900 [569.32 sec]: learning rate : 0.000025 loss : 0.641297 +[20:01:18.390] iteration 57000 [655.71 sec]: learning rate : 0.000025 loss : 0.580175 +[20:02:44.723] iteration 57100 [742.04 sec]: learning rate : 0.000025 loss : 0.523631 +[20:04:11.122] iteration 57200 [828.44 sec]: learning rate : 0.000025 loss : 0.579515 +[20:05:37.445] iteration 57300 [914.77 sec]: learning rate : 0.000025 loss : 0.694875 +[20:07:03.733] iteration 57400 [1001.05 sec]: learning rate : 0.000025 loss : 0.479762 +[20:08:30.068] iteration 57500 [1087.39 sec]: learning rate : 0.000025 loss : 0.727026 +[20:09:56.475] iteration 57600 [1173.79 sec]: learning rate : 0.000025 loss : 0.528517 +[20:11:22.824] iteration 57700 [1260.14 sec]: learning rate : 0.000025 loss : 0.562245 +[20:12:49.249] iteration 57800 [1346.57 sec]: learning rate : 0.000025 loss : 0.334714 +[20:14:15.581] iteration 57900 [1432.90 sec]: learning rate : 0.000025 loss : 0.495926 +[20:15:41.862] iteration 58000 [1519.18 sec]: learning rate : 0.000025 loss : 0.452643 +[20:17:08.212] iteration 58100 [1605.53 sec]: learning 
rate : 0.000025 loss : 1.198597 +[20:18:34.547] iteration 58200 [1691.87 sec]: learning rate : 0.000025 loss : 0.566201 +[20:20:00.892] iteration 58300 [1778.21 sec]: learning rate : 0.000025 loss : 0.435220 +[20:20:21.585] Epoch 27 Evaluation: +[20:21:12.586] average MSE: 0.056389112025499344 average PSNR: 28.114000365517665 average SSIM: 0.6003866586980552 +[20:22:18.492] iteration 58400 [65.84 sec]: learning rate : 0.000025 loss : 0.704777 +[20:23:44.811] iteration 58500 [152.16 sec]: learning rate : 0.000025 loss : 0.851851 +[20:25:11.131] iteration 58600 [238.48 sec]: learning rate : 0.000025 loss : 0.602766 +[20:26:37.508] iteration 58700 [324.86 sec]: learning rate : 0.000025 loss : 0.865654 +[20:28:03.829] iteration 58800 [411.18 sec]: learning rate : 0.000025 loss : 0.808106 +[20:29:30.203] iteration 58900 [497.55 sec]: learning rate : 0.000025 loss : 0.906918 +[20:30:56.577] iteration 59000 [583.93 sec]: learning rate : 0.000025 loss : 0.365097 +[20:32:22.908] iteration 59100 [670.26 sec]: learning rate : 0.000025 loss : 0.360995 +[20:33:49.250] iteration 59200 [756.60 sec]: learning rate : 0.000025 loss : 0.683375 +[20:35:15.626] iteration 59300 [842.98 sec]: learning rate : 0.000025 loss : 0.502525 +[20:36:41.944] iteration 59400 [929.30 sec]: learning rate : 0.000025 loss : 0.372384 +[20:38:08.321] iteration 59500 [1015.67 sec]: learning rate : 0.000025 loss : 0.438998 +[20:39:34.708] iteration 59600 [1102.06 sec]: learning rate : 0.000025 loss : 0.672800 +[20:41:01.075] iteration 59700 [1188.43 sec]: learning rate : 0.000025 loss : 0.522849 +[20:42:27.485] iteration 59800 [1274.84 sec]: learning rate : 0.000025 loss : 0.444191 +[20:43:53.850] iteration 59900 [1361.20 sec]: learning rate : 0.000025 loss : 0.367626 +[20:45:20.184] iteration 60000 [1447.54 sec]: learning rate : 0.000006 loss : 0.899959 +[20:45:20.341] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_60000.pth +[20:46:46.752] iteration 60100 [1534.10 sec]: learning rate : 0.000013 loss : 0.420051 +[20:48:13.119] iteration 60200 [1620.47 sec]: learning rate : 0.000013 loss : 0.668864 +[20:49:39.447] iteration 60300 [1706.80 sec]: learning rate : 0.000013 loss : 0.645890 +[20:51:05.782] iteration 60400 [1793.13 sec]: learning rate : 0.000013 loss : 0.335497 +[20:51:11.800] Epoch 28 Evaluation: +[20:52:01.339] average MSE: 0.05619518831372261 average PSNR: 28.116767542809228 average SSIM: 0.6003082456851997 +[20:53:21.804] iteration 60500 [80.40 sec]: learning rate : 0.000013 loss : 0.513686 +[20:54:48.182] iteration 60600 [166.78 sec]: learning rate : 0.000013 loss : 0.516224 +[20:56:14.521] iteration 60700 [253.12 sec]: learning rate : 0.000013 loss : 0.434842 +[20:57:40.825] iteration 60800 [339.42 sec]: learning rate : 0.000013 loss : 0.587013 +[20:59:07.231] iteration 60900 [425.83 sec]: learning rate : 0.000013 loss : 0.495428 +[21:00:33.628] iteration 61000 [512.23 sec]: learning rate : 0.000013 loss : 0.596224 +[21:01:59.967] iteration 61100 [598.56 sec]: learning rate : 0.000013 loss : 0.767727 +[21:03:26.357] iteration 61200 [684.96 sec]: learning rate : 0.000013 loss : 0.758707 +[21:04:52.756] iteration 61300 [771.35 sec]: learning rate : 0.000013 loss : 0.413874 +[21:06:19.084] iteration 61400 [857.68 sec]: learning rate : 0.000013 loss : 0.308358 +[21:07:45.521] iteration 61500 [944.12 sec]: learning rate : 0.000013 loss : 0.590468 +[21:09:11.871] iteration 61600 [1030.47 sec]: learning rate : 0.000013 loss : 0.673886 +[21:10:38.280] iteration 61700 [1116.88 sec]: learning rate : 0.000013 loss 
: 0.626187 +[21:12:04.691] iteration 61800 [1203.29 sec]: learning rate : 0.000013 loss : 0.347859 +[21:13:31.037] iteration 61900 [1289.64 sec]: learning rate : 0.000013 loss : 0.504482 +[21:14:57.417] iteration 62000 [1376.01 sec]: learning rate : 0.000013 loss : 0.518345 +[21:16:23.770] iteration 62100 [1462.37 sec]: learning rate : 0.000013 loss : 0.577231 +[21:17:50.070] iteration 62200 [1548.67 sec]: learning rate : 0.000013 loss : 0.716170 +[21:19:16.477] iteration 62300 [1635.07 sec]: learning rate : 0.000013 loss : 0.834360 +[21:20:42.894] iteration 62400 [1721.49 sec]: learning rate : 0.000013 loss : 0.590074 +[21:22:00.552] Epoch 29 Evaluation: +[21:22:50.054] average MSE: 0.05592525005340576 average PSNR: 28.14547380794146 average SSIM: 0.6007280102750981 +[21:22:58.935] iteration 62500 [8.82 sec]: learning rate : 0.000013 loss : 0.829882 +[21:24:25.325] iteration 62600 [95.21 sec]: learning rate : 0.000013 loss : 0.603118 +[21:25:51.699] iteration 62700 [181.58 sec]: learning rate : 0.000013 loss : 0.783261 +[21:27:18.014] iteration 62800 [267.90 sec]: learning rate : 0.000013 loss : 0.328847 +[21:28:44.394] iteration 62900 [354.28 sec]: learning rate : 0.000013 loss : 0.452125 +[21:30:10.772] iteration 63000 [440.65 sec]: learning rate : 0.000013 loss : 0.644129 +[21:31:37.119] iteration 63100 [527.00 sec]: learning rate : 0.000013 loss : 0.551184 +[21:33:03.513] iteration 63200 [613.40 sec]: learning rate : 0.000013 loss : 0.968010 +[21:34:29.907] iteration 63300 [699.79 sec]: learning rate : 0.000013 loss : 0.601891 +[21:35:56.247] iteration 63400 [786.13 sec]: learning rate : 0.000013 loss : 0.392574 +[21:37:22.642] iteration 63500 [872.53 sec]: learning rate : 0.000013 loss : 0.507656 +[21:38:49.035] iteration 63600 [958.92 sec]: learning rate : 0.000013 loss : 0.349420 +[21:40:15.365] iteration 63700 [1045.25 sec]: learning rate : 0.000013 loss : 0.462920 +[21:41:41.737] iteration 63800 [1131.62 sec]: learning rate : 0.000013 loss : 0.615645 +[21:43:08.043] iteration 63900 [1217.93 sec]: learning rate : 0.000013 loss : 0.444645 +[21:44:34.343] iteration 64000 [1304.23 sec]: learning rate : 0.000013 loss : 1.044305 +[21:46:00.692] iteration 64100 [1390.58 sec]: learning rate : 0.000013 loss : 0.637985 +[21:47:27.082] iteration 64200 [1476.97 sec]: learning rate : 0.000013 loss : 0.572424 +[21:48:53.416] iteration 64300 [1563.30 sec]: learning rate : 0.000013 loss : 0.618718 +[21:50:19.790] iteration 64400 [1649.67 sec]: learning rate : 0.000013 loss : 0.837782 +[21:51:46.187] iteration 64500 [1736.07 sec]: learning rate : 0.000013 loss : 0.751763 +[21:52:49.126] Epoch 30 Evaluation: +[21:53:40.184] average MSE: 0.05626516416668892 average PSNR: 28.113280331506548 average SSIM: 0.6007781241881867 +[21:54:03.717] iteration 64600 [23.47 sec]: learning rate : 0.000013 loss : 0.548063 +[21:55:30.171] iteration 64700 [109.93 sec]: learning rate : 0.000013 loss : 0.581545 +[21:56:56.516] iteration 64800 [196.27 sec]: learning rate : 0.000013 loss : 0.435780 +[21:58:22.810] iteration 64900 [282.56 sec]: learning rate : 0.000013 loss : 0.657186 +[21:59:49.215] iteration 65000 [368.97 sec]: learning rate : 0.000013 loss : 0.778300 +[22:01:15.590] iteration 65100 [455.34 sec]: learning rate : 0.000013 loss : 0.799735 +[22:02:41.916] iteration 65200 [541.67 sec]: learning rate : 0.000013 loss : 0.824624 +[22:04:08.340] iteration 65300 [628.09 sec]: learning rate : 0.000013 loss : 0.567646 +[22:05:34.689] iteration 65400 [714.44 sec]: learning rate : 0.000013 loss : 0.529764 
+[22:07:01.060] iteration 65500 [800.81 sec]: learning rate : 0.000013 loss : 0.778484 +[22:08:27.482] iteration 65600 [887.24 sec]: learning rate : 0.000013 loss : 0.502536 +[22:09:53.916] iteration 65700 [973.67 sec]: learning rate : 0.000013 loss : 0.520733 +[22:11:20.265] iteration 65800 [1060.02 sec]: learning rate : 0.000013 loss : 0.849569 +[22:12:46.648] iteration 65900 [1146.40 sec]: learning rate : 0.000013 loss : 0.456813 +[22:14:12.973] iteration 66000 [1232.73 sec]: learning rate : 0.000013 loss : 0.767876 +[22:15:39.267] iteration 66100 [1319.02 sec]: learning rate : 0.000013 loss : 0.591448 +[22:17:05.592] iteration 66200 [1405.34 sec]: learning rate : 0.000013 loss : 0.704272 +[22:18:31.927] iteration 66300 [1491.68 sec]: learning rate : 0.000013 loss : 0.849690 +[22:19:58.271] iteration 66400 [1578.02 sec]: learning rate : 0.000013 loss : 0.379496 +[22:21:24.599] iteration 66500 [1664.35 sec]: learning rate : 0.000013 loss : 0.570464 +[22:22:50.922] iteration 66600 [1750.68 sec]: learning rate : 0.000013 loss : 0.456289 +[22:23:39.206] Epoch 31 Evaluation: +[22:24:28.416] average MSE: 0.055789340287446976 average PSNR: 28.15768563628285 average SSIM: 0.6008403835562545 +[22:25:06.716] iteration 66700 [38.24 sec]: learning rate : 0.000013 loss : 0.670638 +[22:26:33.007] iteration 66800 [124.53 sec]: learning rate : 0.000013 loss : 0.461963 +[22:27:59.368] iteration 66900 [210.89 sec]: learning rate : 0.000013 loss : 0.671837 +[22:29:25.763] iteration 67000 [297.28 sec]: learning rate : 0.000013 loss : 0.727502 +[22:30:52.106] iteration 67100 [383.63 sec]: learning rate : 0.000013 loss : 0.308984 +[22:32:18.478] iteration 67200 [470.00 sec]: learning rate : 0.000013 loss : 0.545264 +[22:33:44.863] iteration 67300 [556.38 sec]: learning rate : 0.000013 loss : 0.486028 +[22:35:11.272] iteration 67400 [642.79 sec]: learning rate : 0.000013 loss : 0.649036 +[22:36:37.647] iteration 67500 [729.17 sec]: learning rate : 0.000013 loss : 0.817107 +[22:38:04.007] iteration 67600 [815.53 sec]: learning rate : 0.000013 loss : 0.490738 +[22:39:30.430] iteration 67700 [901.95 sec]: learning rate : 0.000013 loss : 0.361900 +[22:40:56.805] iteration 67800 [988.33 sec]: learning rate : 0.000013 loss : 0.311577 +[22:42:23.225] iteration 67900 [1074.75 sec]: learning rate : 0.000013 loss : 0.819089 +[22:43:49.619] iteration 68000 [1161.14 sec]: learning rate : 0.000013 loss : 0.507789 +[22:45:15.923] iteration 68100 [1247.44 sec]: learning rate : 0.000013 loss : 0.615019 +[22:46:42.283] iteration 68200 [1333.80 sec]: learning rate : 0.000013 loss : 0.932305 +[22:48:08.595] iteration 68300 [1420.12 sec]: learning rate : 0.000013 loss : 0.758586 +[22:49:34.939] iteration 68400 [1506.46 sec]: learning rate : 0.000013 loss : 0.840786 +[22:51:01.302] iteration 68500 [1592.82 sec]: learning rate : 0.000013 loss : 0.741051 +[22:52:27.657] iteration 68600 [1679.18 sec]: learning rate : 0.000013 loss : 0.608340 +[22:53:54.084] iteration 68700 [1765.61 sec]: learning rate : 0.000013 loss : 0.988158 +[22:54:27.730] Epoch 32 Evaluation: +[22:55:18.646] average MSE: 0.055753204971551895 average PSNR: 28.159244940350728 average SSIM: 0.6005058001635755 +[22:56:11.624] iteration 68800 [52.91 sec]: learning rate : 0.000013 loss : 1.134624 +[22:57:37.916] iteration 68900 [139.21 sec]: learning rate : 0.000013 loss : 0.891350 +[22:59:04.311] iteration 69000 [225.60 sec]: learning rate : 0.000013 loss : 0.694438 +[23:00:30.656] iteration 69100 [311.95 sec]: learning rate : 0.000013 loss : 0.614259 
+[23:01:57.053] iteration 69200 [398.34 sec]: learning rate : 0.000013 loss : 0.576417 +[23:03:23.403] iteration 69300 [484.69 sec]: learning rate : 0.000013 loss : 0.622439 +[23:04:49.718] iteration 69400 [571.01 sec]: learning rate : 0.000013 loss : 0.540371 +[23:06:16.124] iteration 69500 [657.41 sec]: learning rate : 0.000013 loss : 0.492688 +[23:07:42.492] iteration 69600 [743.78 sec]: learning rate : 0.000013 loss : 0.639747 +[23:09:08.806] iteration 69700 [830.10 sec]: learning rate : 0.000013 loss : 0.303690 +[23:10:35.218] iteration 69800 [916.51 sec]: learning rate : 0.000013 loss : 0.515520 +[23:12:01.566] iteration 69900 [1002.86 sec]: learning rate : 0.000013 loss : 0.913951 +[23:13:27.942] iteration 70000 [1089.23 sec]: learning rate : 0.000013 loss : 0.772938 +[23:14:54.326] iteration 70100 [1175.62 sec]: learning rate : 0.000013 loss : 0.473070 +[23:16:20.679] iteration 70200 [1261.97 sec]: learning rate : 0.000013 loss : 0.555122 +[23:17:47.035] iteration 70300 [1348.33 sec]: learning rate : 0.000013 loss : 0.795626 +[23:19:13.385] iteration 70400 [1434.67 sec]: learning rate : 0.000013 loss : 0.280610 +[23:20:39.729] iteration 70500 [1521.02 sec]: learning rate : 0.000013 loss : 0.548207 +[23:22:06.104] iteration 70600 [1607.39 sec]: learning rate : 0.000013 loss : 0.425044 +[23:23:32.438] iteration 70700 [1693.73 sec]: learning rate : 0.000013 loss : 0.614458 +[23:24:58.830] iteration 70800 [1780.12 sec]: learning rate : 0.000013 loss : 0.451389 +[23:25:17.791] Epoch 33 Evaluation: +[23:26:09.211] average MSE: 0.05568737909197807 average PSNR: 28.15759661680794 average SSIM: 0.6011731134766177 +[23:27:16.863] iteration 70900 [67.59 sec]: learning rate : 0.000013 loss : 0.565970 +[23:28:43.159] iteration 71000 [153.89 sec]: learning rate : 0.000013 loss : 0.827929 +[23:30:09.532] iteration 71100 [240.26 sec]: learning rate : 0.000013 loss : 0.460453 +[23:31:35.886] iteration 71200 [326.61 sec]: learning rate : 0.000013 loss : 0.564502 +[23:33:02.179] iteration 71300 [412.91 sec]: learning rate : 0.000013 loss : 0.430109 +[23:34:28.522] iteration 71400 [499.25 sec]: learning rate : 0.000013 loss : 0.532710 +[23:35:54.880] iteration 71500 [585.61 sec]: learning rate : 0.000013 loss : 0.532754 +[23:37:21.180] iteration 71600 [671.91 sec]: learning rate : 0.000013 loss : 0.557462 +[23:38:47.516] iteration 71700 [758.24 sec]: learning rate : 0.000013 loss : 0.678324 +[23:40:13.786] iteration 71800 [844.51 sec]: learning rate : 0.000013 loss : 0.481705 +[23:41:40.140] iteration 71900 [930.87 sec]: learning rate : 0.000013 loss : 0.852941 +[23:43:06.486] iteration 72000 [1017.21 sec]: learning rate : 0.000013 loss : 0.678856 +[23:44:32.762] iteration 72100 [1103.49 sec]: learning rate : 0.000013 loss : 0.926136 +[23:45:59.168] iteration 72200 [1189.89 sec]: learning rate : 0.000013 loss : 0.552447 +[23:47:25.493] iteration 72300 [1276.22 sec]: learning rate : 0.000013 loss : 0.510137 +[23:48:51.807] iteration 72400 [1362.53 sec]: learning rate : 0.000013 loss : 0.552241 +[23:50:18.179] iteration 72500 [1448.90 sec]: learning rate : 0.000013 loss : 0.640253 +[23:51:44.570] iteration 72600 [1535.30 sec]: learning rate : 0.000013 loss : 0.390412 +[23:53:10.875] iteration 72700 [1621.61 sec]: learning rate : 0.000013 loss : 0.550902 +[23:54:37.211] iteration 72800 [1707.94 sec]: learning rate : 0.000013 loss : 0.576492 +[23:56:03.610] iteration 72900 [1794.34 sec]: learning rate : 0.000013 loss : 0.667425 +[23:56:07.903] Epoch 34 Evaluation: +[23:56:57.203] average MSE: 
0.055668603628873825 average PSNR: 28.164459456214033 average SSIM: 0.60095296692346 +[23:58:19.441] iteration 73000 [82.18 sec]: learning rate : 0.000013 loss : 0.909840 +[23:59:45.867] iteration 73100 [168.60 sec]: learning rate : 0.000013 loss : 0.707929 +[00:01:12.242] iteration 73200 [254.98 sec]: learning rate : 0.000013 loss : 0.588089 +[00:02:38.567] iteration 73300 [341.30 sec]: learning rate : 0.000013 loss : 0.435457 +[00:04:04.932] iteration 73400 [427.67 sec]: learning rate : 0.000013 loss : 0.484258 +[00:05:31.310] iteration 73500 [514.04 sec]: learning rate : 0.000013 loss : 0.494866 +[00:06:57.715] iteration 73600 [600.45 sec]: learning rate : 0.000013 loss : 0.578789 +[00:08:24.136] iteration 73700 [686.87 sec]: learning rate : 0.000013 loss : 0.534729 +[00:09:50.485] iteration 73800 [773.22 sec]: learning rate : 0.000013 loss : 0.460183 +[00:11:16.888] iteration 73900 [859.62 sec]: learning rate : 0.000013 loss : 0.634385 +[00:12:43.307] iteration 74000 [946.04 sec]: learning rate : 0.000013 loss : 0.503162 +[00:14:09.718] iteration 74100 [1032.45 sec]: learning rate : 0.000013 loss : 0.847286 +[00:15:36.152] iteration 74200 [1118.89 sec]: learning rate : 0.000013 loss : 0.506029 +[00:17:02.533] iteration 74300 [1205.27 sec]: learning rate : 0.000013 loss : 0.538067 +[00:18:28.885] iteration 74400 [1291.62 sec]: learning rate : 0.000013 loss : 0.629792 +[00:19:55.305] iteration 74500 [1378.04 sec]: learning rate : 0.000013 loss : 0.367690 +[00:21:21.645] iteration 74600 [1464.38 sec]: learning rate : 0.000013 loss : 0.278065 +[00:22:48.017] iteration 74700 [1550.75 sec]: learning rate : 0.000013 loss : 0.553630 +[00:24:14.413] iteration 74800 [1637.15 sec]: learning rate : 0.000013 loss : 0.583044 +[00:25:40.818] iteration 74900 [1723.55 sec]: learning rate : 0.000013 loss : 0.548391 +[00:26:56.799] Epoch 35 Evaluation: +[00:27:48.541] average MSE: 0.05592691898345947 average PSNR: 28.144713675907422 average SSIM: 0.6012151935015658 +[00:27:59.129] iteration 75000 [10.53 sec]: learning rate : 0.000013 loss : 0.406057 +[00:29:25.440] iteration 75100 [96.84 sec]: learning rate : 0.000013 loss : 0.868090 +[00:30:51.835] iteration 75200 [183.23 sec]: learning rate : 0.000013 loss : 0.522652 +[00:32:18.125] iteration 75300 [269.52 sec]: learning rate : 0.000013 loss : 0.541537 +[00:33:44.461] iteration 75400 [355.86 sec]: learning rate : 0.000013 loss : 0.437796 +[00:35:10.818] iteration 75500 [442.21 sec]: learning rate : 0.000013 loss : 0.453514 +[00:36:37.139] iteration 75600 [528.53 sec]: learning rate : 0.000013 loss : 0.319442 +[00:38:03.459] iteration 75700 [614.85 sec]: learning rate : 0.000013 loss : 0.849106 +[00:39:29.738] iteration 75800 [701.13 sec]: learning rate : 0.000013 loss : 0.592475 +[00:40:56.003] iteration 75900 [787.40 sec]: learning rate : 0.000013 loss : 0.437093 +[00:42:22.329] iteration 76000 [873.72 sec]: learning rate : 0.000013 loss : 0.576493 +[00:43:48.573] iteration 76100 [959.97 sec]: learning rate : 0.000013 loss : 0.575530 +[00:45:14.882] iteration 76200 [1046.28 sec]: learning rate : 0.000013 loss : 0.398230 +[00:46:41.165] iteration 76300 [1132.56 sec]: learning rate : 0.000013 loss : 0.803274 +[00:48:07.425] iteration 76400 [1218.82 sec]: learning rate : 0.000013 loss : 0.616831 +[00:49:33.705] iteration 76500 [1305.10 sec]: learning rate : 0.000013 loss : 0.501268 +[00:50:59.928] iteration 76600 [1391.32 sec]: learning rate : 0.000013 loss : 0.904630 +[00:52:26.154] iteration 76700 [1477.55 sec]: learning rate : 0.000013 loss : 0.593830 
+[00:53:52.418] iteration 76800 [1563.81 sec]: learning rate : 0.000013 loss : 0.443780 +[00:55:18.613] iteration 76900 [1650.01 sec]: learning rate : 0.000013 loss : 0.956888 +[00:56:44.910] iteration 77000 [1736.31 sec]: learning rate : 0.000013 loss : 0.610551 +[00:57:46.106] Epoch 36 Evaluation: +[00:58:35.539] average MSE: 0.055373415350914 average PSNR: 28.188367178810733 average SSIM: 0.6004604827550603 +[00:59:00.879] iteration 77100 [25.28 sec]: learning rate : 0.000013 loss : 0.621343 +[01:00:27.056] iteration 77200 [111.45 sec]: learning rate : 0.000013 loss : 0.369829 +[01:01:53.279] iteration 77300 [197.68 sec]: learning rate : 0.000013 loss : 0.497497 +[01:03:19.506] iteration 77400 [283.91 sec]: learning rate : 0.000013 loss : 0.366656 +[01:04:45.690] iteration 77500 [370.09 sec]: learning rate : 0.000013 loss : 0.697896 +[01:06:11.924] iteration 77600 [456.32 sec]: learning rate : 0.000013 loss : 0.772045 +[01:07:38.178] iteration 77700 [542.58 sec]: learning rate : 0.000013 loss : 0.692415 +[01:09:04.361] iteration 77800 [628.76 sec]: learning rate : 0.000013 loss : 0.545107 +[01:10:30.662] iteration 77900 [715.06 sec]: learning rate : 0.000013 loss : 0.621703 +[01:11:56.890] iteration 78000 [801.29 sec]: learning rate : 0.000013 loss : 0.540336 +[01:13:23.196] iteration 78100 [887.60 sec]: learning rate : 0.000013 loss : 0.485293 +[01:14:49.504] iteration 78200 [973.90 sec]: learning rate : 0.000013 loss : 0.715219 +[01:16:15.741] iteration 78300 [1060.14 sec]: learning rate : 0.000013 loss : 0.398479 +[01:17:42.039] iteration 78400 [1146.44 sec]: learning rate : 0.000013 loss : 0.731404 +[01:19:08.314] iteration 78500 [1232.71 sec]: learning rate : 0.000013 loss : 0.692473 +[01:20:34.548] iteration 78600 [1318.95 sec]: learning rate : 0.000013 loss : 0.444958 +[01:22:00.778] iteration 78700 [1405.18 sec]: learning rate : 0.000013 loss : 0.643852 +[01:23:27.069] iteration 78800 [1491.47 sec]: learning rate : 0.000013 loss : 0.558725 +[01:24:53.286] iteration 78900 [1577.69 sec]: learning rate : 0.000013 loss : 0.437634 +[01:26:19.581] iteration 79000 [1663.98 sec]: learning rate : 0.000013 loss : 0.305686 +[01:27:45.790] iteration 79100 [1750.19 sec]: learning rate : 0.000013 loss : 0.495392 +[01:28:32.377] Epoch 37 Evaluation: +[01:29:23.046] average MSE: 0.05568062141537666 average PSNR: 28.167079375728814 average SSIM: 0.6011339216871895 +[01:30:02.914] iteration 79200 [39.81 sec]: learning rate : 0.000013 loss : 0.325504 +[01:31:29.210] iteration 79300 [126.10 sec]: learning rate : 0.000013 loss : 0.589475 +[01:32:55.341] iteration 79400 [212.23 sec]: learning rate : 0.000013 loss : 1.015221 +[01:34:21.522] iteration 79500 [298.41 sec]: learning rate : 0.000013 loss : 0.575494 +[01:35:47.743] iteration 79600 [384.63 sec]: learning rate : 0.000013 loss : 1.139600 +[01:37:13.890] iteration 79700 [470.78 sec]: learning rate : 0.000013 loss : 0.600211 +[01:38:40.110] iteration 79800 [557.00 sec]: learning rate : 0.000013 loss : 0.775319 +[01:40:06.347] iteration 79900 [643.24 sec]: learning rate : 0.000013 loss : 0.479234 +[01:41:32.515] iteration 80000 [729.41 sec]: learning rate : 0.000003 loss : 0.794919 +[01:41:32.673] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_80000.pth +[01:42:58.918] iteration 80100 [815.81 sec]: learning rate : 0.000006 loss : 0.767010 +[01:44:25.179] iteration 80200 [902.07 sec]: learning rate : 0.000006 loss : 0.610945 +[01:45:51.381] iteration 80300 [988.27 sec]: learning rate : 0.000006 loss : 0.489648 +[01:47:17.648] iteration 
80400 [1074.54 sec]: learning rate : 0.000006 loss : 0.579124 +[01:48:43.871] iteration 80500 [1160.76 sec]: learning rate : 0.000006 loss : 0.788408 +[01:50:10.151] iteration 80600 [1247.04 sec]: learning rate : 0.000006 loss : 0.433448 +[01:51:36.429] iteration 80700 [1333.32 sec]: learning rate : 0.000006 loss : 0.385736 +[01:53:02.645] iteration 80800 [1419.54 sec]: learning rate : 0.000006 loss : 0.342118 +[01:54:28.924] iteration 80900 [1505.82 sec]: learning rate : 0.000006 loss : 0.544238 +[01:55:55.194] iteration 81000 [1592.09 sec]: learning rate : 0.000006 loss : 1.047231 +[01:57:21.420] iteration 81100 [1678.31 sec]: learning rate : 0.000006 loss : 0.521709 +[01:58:47.711] iteration 81200 [1764.60 sec]: learning rate : 0.000006 loss : 0.236921 +[01:59:19.596] Epoch 38 Evaluation: +[02:00:09.040] average MSE: 0.05565338954329491 average PSNR: 28.168759894748224 average SSIM: 0.6015153044353531 +[02:01:03.684] iteration 81300 [54.58 sec]: learning rate : 0.000006 loss : 0.452919 +[02:02:29.859] iteration 81400 [140.76 sec]: learning rate : 0.000006 loss : 0.580437 +[02:03:56.038] iteration 81500 [226.94 sec]: learning rate : 0.000006 loss : 0.984150 +[02:05:22.295] iteration 81600 [313.19 sec]: learning rate : 0.000006 loss : 0.665704 +[02:06:48.565] iteration 81700 [399.46 sec]: learning rate : 0.000006 loss : 0.774587 +[02:08:14.899] iteration 81800 [485.80 sec]: learning rate : 0.000006 loss : 0.575763 +[02:09:41.182] iteration 81900 [572.08 sec]: learning rate : 0.000006 loss : 0.367750 +[02:11:07.514] iteration 82000 [658.41 sec]: learning rate : 0.000006 loss : 0.646751 +[02:12:33.842] iteration 82100 [744.74 sec]: learning rate : 0.000006 loss : 0.731308 +[02:14:00.103] iteration 82200 [831.00 sec]: learning rate : 0.000006 loss : 0.440924 +[02:15:26.429] iteration 82300 [917.33 sec]: learning rate : 0.000006 loss : 0.495068 +[02:16:52.760] iteration 82400 [1003.66 sec]: learning rate : 0.000006 loss : 0.618746 +[02:18:19.058] iteration 82500 [1089.96 sec]: learning rate : 0.000006 loss : 0.711019 +[02:19:45.410] iteration 82600 [1176.32 sec]: learning rate : 0.000006 loss : 0.462918 +[02:21:11.776] iteration 82700 [1262.67 sec]: learning rate : 0.000006 loss : 0.642426 +[02:22:38.034] iteration 82800 [1348.93 sec]: learning rate : 0.000006 loss : 0.561678 +[02:24:04.379] iteration 82900 [1435.28 sec]: learning rate : 0.000006 loss : 0.411135 +[02:25:30.700] iteration 83000 [1521.60 sec]: learning rate : 0.000006 loss : 0.499035 +[02:26:57.008] iteration 83100 [1607.91 sec]: learning rate : 0.000006 loss : 0.353987 +[02:28:23.369] iteration 83200 [1694.27 sec]: learning rate : 0.000006 loss : 0.428507 +[02:29:49.677] iteration 83300 [1780.57 sec]: learning rate : 0.000006 loss : 0.568716 +[02:30:06.915] Epoch 39 Evaluation: +[02:30:56.544] average MSE: 0.055457353591918945 average PSNR: 28.189342297065014 average SSIM: 0.6013374912259398 +[02:32:05.966] iteration 83400 [69.36 sec]: learning rate : 0.000006 loss : 0.660227 +[02:33:32.314] iteration 83500 [155.71 sec]: learning rate : 0.000006 loss : 0.795670 +[02:34:58.623] iteration 83600 [242.02 sec]: learning rate : 0.000006 loss : 0.652668 +[02:36:24.941] iteration 83700 [328.33 sec]: learning rate : 0.000006 loss : 0.859137 +[02:37:51.231] iteration 83800 [414.62 sec]: learning rate : 0.000006 loss : 0.405168 +[02:39:17.568] iteration 83900 [500.96 sec]: learning rate : 0.000006 loss : 0.420888 +[02:40:43.947] iteration 84000 [587.34 sec]: learning rate : 0.000006 loss : 0.561474 +[02:42:10.275] iteration 84100 [673.67 
sec]: learning rate : 0.000006 loss : 0.427340 +[02:43:36.651] iteration 84200 [760.04 sec]: learning rate : 0.000006 loss : 0.599593 +[02:45:03.022] iteration 84300 [846.42 sec]: learning rate : 0.000006 loss : 0.585554 +[02:46:29.364] iteration 84400 [932.76 sec]: learning rate : 0.000006 loss : 0.628092 +[02:47:55.743] iteration 84500 [1019.14 sec]: learning rate : 0.000006 loss : 0.690272 +[02:49:22.084] iteration 84600 [1105.48 sec]: learning rate : 0.000006 loss : 0.448883 +[02:50:48.381] iteration 84700 [1191.77 sec]: learning rate : 0.000006 loss : 0.548638 +[02:52:14.741] iteration 84800 [1278.13 sec]: learning rate : 0.000006 loss : 0.405843 +[02:53:41.142] iteration 84900 [1364.53 sec]: learning rate : 0.000006 loss : 0.686892 +[02:55:07.485] iteration 85000 [1450.88 sec]: learning rate : 0.000006 loss : 0.701207 +[02:56:33.885] iteration 85100 [1537.28 sec]: learning rate : 0.000006 loss : 0.430364 +[02:58:00.220] iteration 85200 [1623.61 sec]: learning rate : 0.000006 loss : 0.662351 +[02:59:26.617] iteration 85300 [1710.01 sec]: learning rate : 0.000006 loss : 1.135294 +[03:00:53.027] iteration 85400 [1796.42 sec]: learning rate : 0.000006 loss : 0.652482 +[03:00:55.595] Epoch 40 Evaluation: +[03:01:44.901] average MSE: 0.05554268881678581 average PSNR: 28.177056300051174 average SSIM: 0.6009144145588156 +[03:03:08.834] iteration 85500 [83.87 sec]: learning rate : 0.000006 loss : 0.594242 +[03:04:35.205] iteration 85600 [170.24 sec]: learning rate : 0.000006 loss : 0.488859 +[03:06:01.523] iteration 85700 [256.56 sec]: learning rate : 0.000006 loss : 0.427697 +[03:07:27.861] iteration 85800 [342.90 sec]: learning rate : 0.000006 loss : 0.383633 +[03:08:54.224] iteration 85900 [429.26 sec]: learning rate : 0.000006 loss : 0.872830 +[03:10:20.565] iteration 86000 [515.60 sec]: learning rate : 0.000006 loss : 0.695347 +[03:11:46.960] iteration 86100 [602.00 sec]: learning rate : 0.000006 loss : 0.538381 +[03:13:13.309] iteration 86200 [688.35 sec]: learning rate : 0.000006 loss : 0.610429 +[03:14:39.597] iteration 86300 [774.63 sec]: learning rate : 0.000006 loss : 0.781970 +[03:16:05.925] iteration 86400 [860.96 sec]: learning rate : 0.000006 loss : 0.317976 +[03:17:32.289] iteration 86500 [947.33 sec]: learning rate : 0.000006 loss : 0.675704 +[03:18:58.661] iteration 86600 [1033.70 sec]: learning rate : 0.000006 loss : 0.479122 +[03:20:25.030] iteration 86700 [1120.07 sec]: learning rate : 0.000006 loss : 1.039813 +[03:21:51.397] iteration 86800 [1206.43 sec]: learning rate : 0.000006 loss : 0.647531 +[03:23:17.825] iteration 86900 [1292.86 sec]: learning rate : 0.000006 loss : 0.674854 +[03:24:44.172] iteration 87000 [1379.21 sec]: learning rate : 0.000006 loss : 0.723163 +[03:26:10.451] iteration 87100 [1465.49 sec]: learning rate : 0.000006 loss : 0.539380 +[03:27:36.773] iteration 87200 [1551.81 sec]: learning rate : 0.000006 loss : 0.543552 +[03:29:03.099] iteration 87300 [1638.13 sec]: learning rate : 0.000006 loss : 0.474267 +[03:30:29.358] iteration 87400 [1724.39 sec]: learning rate : 0.000006 loss : 0.444141 +[03:31:43.521] Epoch 41 Evaluation: +[03:32:35.244] average MSE: 0.0554439015686512 average PSNR: 28.185553562388463 average SSIM: 0.6011600358742529 +[03:32:47.593] iteration 87500 [12.29 sec]: learning rate : 0.000006 loss : 0.580708 +[03:34:14.052] iteration 87600 [98.75 sec]: learning rate : 0.000006 loss : 0.655151 +[03:35:40.356] iteration 87700 [185.05 sec]: learning rate : 0.000006 loss : 0.307026 +[03:37:06.720] iteration 87800 [271.41 sec]: learning 
rate : 0.000006 loss : 0.600674 +[03:38:33.085] iteration 87900 [357.78 sec]: learning rate : 0.000006 loss : 0.667549 +[03:39:59.397] iteration 88000 [444.09 sec]: learning rate : 0.000006 loss : 1.188715 +[03:41:25.714] iteration 88100 [530.41 sec]: learning rate : 0.000006 loss : 0.471698 +[03:42:52.043] iteration 88200 [616.74 sec]: learning rate : 0.000006 loss : 0.280871 +[03:44:18.419] iteration 88300 [703.11 sec]: learning rate : 0.000006 loss : 0.451987 +[03:45:44.786] iteration 88400 [789.48 sec]: learning rate : 0.000006 loss : 0.528314 +[03:47:11.096] iteration 88500 [875.79 sec]: learning rate : 0.000006 loss : 0.368032 +[03:48:37.504] iteration 88600 [962.20 sec]: learning rate : 0.000006 loss : 0.629716 +[03:50:03.887] iteration 88700 [1048.58 sec]: learning rate : 0.000006 loss : 0.589943 +[03:51:30.213] iteration 88800 [1134.91 sec]: learning rate : 0.000006 loss : 1.011075 +[03:52:56.588] iteration 88900 [1221.28 sec]: learning rate : 0.000006 loss : 0.322541 +[03:54:22.847] iteration 89000 [1307.54 sec]: learning rate : 0.000006 loss : 0.602014 +[03:55:49.157] iteration 89100 [1393.85 sec]: learning rate : 0.000006 loss : 0.578843 +[03:57:15.491] iteration 89200 [1480.18 sec]: learning rate : 0.000006 loss : 0.719934 +[03:58:41.795] iteration 89300 [1566.49 sec]: learning rate : 0.000006 loss : 0.786101 +[04:00:08.175] iteration 89400 [1652.87 sec]: learning rate : 0.000006 loss : 0.654458 +[04:01:34.510] iteration 89500 [1739.20 sec]: learning rate : 0.000006 loss : 0.607560 +[04:02:34.036] Epoch 42 Evaluation: +[04:03:24.377] average MSE: 0.055571407079696655 average PSNR: 28.17983615502294 average SSIM: 0.6012369312764064 +[04:03:51.380] iteration 89600 [26.94 sec]: learning rate : 0.000006 loss : 0.726017 +[04:05:17.728] iteration 89700 [113.29 sec]: learning rate : 0.000006 loss : 0.681758 +[04:06:44.066] iteration 89800 [199.62 sec]: learning rate : 0.000006 loss : 1.011563 +[04:08:10.349] iteration 89900 [285.91 sec]: learning rate : 0.000006 loss : 0.695841 +[04:09:36.683] iteration 90000 [372.24 sec]: learning rate : 0.000006 loss : 0.644906 +[04:11:03.025] iteration 90100 [458.58 sec]: learning rate : 0.000006 loss : 0.429014 +[04:12:29.309] iteration 90200 [544.87 sec]: learning rate : 0.000006 loss : 0.730108 +[04:13:55.626] iteration 90300 [631.19 sec]: learning rate : 0.000006 loss : 0.438007 +[04:15:21.905] iteration 90400 [717.47 sec]: learning rate : 0.000006 loss : 1.155092 +[04:16:48.181] iteration 90500 [803.74 sec]: learning rate : 0.000006 loss : 0.501558 +[04:18:14.505] iteration 90600 [890.06 sec]: learning rate : 0.000006 loss : 0.268092 +[04:19:40.810] iteration 90700 [976.37 sec]: learning rate : 0.000006 loss : 0.482601 +[04:21:07.125] iteration 90800 [1062.69 sec]: learning rate : 0.000006 loss : 0.548358 +[04:22:33.497] iteration 90900 [1149.06 sec]: learning rate : 0.000006 loss : 0.922237 +[04:23:59.781] iteration 91000 [1235.34 sec]: learning rate : 0.000006 loss : 0.538255 +[04:25:26.126] iteration 91100 [1321.69 sec]: learning rate : 0.000006 loss : 0.325199 +[04:26:52.474] iteration 91200 [1408.04 sec]: learning rate : 0.000006 loss : 0.759379 +[04:28:18.759] iteration 91300 [1494.32 sec]: learning rate : 0.000006 loss : 0.445293 +[04:29:45.094] iteration 91400 [1580.65 sec]: learning rate : 0.000006 loss : 0.664429 +[04:31:11.392] iteration 91500 [1666.95 sec]: learning rate : 0.000006 loss : 0.467771 +[04:32:37.643] iteration 91600 [1753.20 sec]: learning rate : 0.000006 loss : 0.439004 +[04:33:22.448] Epoch 43 Evaluation: 
+[04:34:13.801] average MSE: 0.055531229823827744 average PSNR: 28.178174973351688 average SSIM: 0.601264085359837 +[04:34:55.453] iteration 91700 [41.59 sec]: learning rate : 0.000006 loss : 0.308554 +[04:36:21.828] iteration 91800 [127.96 sec]: learning rate : 0.000006 loss : 0.455694 +[04:37:48.078] iteration 91900 [214.21 sec]: learning rate : 0.000006 loss : 0.871803 +[04:39:14.416] iteration 92000 [300.55 sec]: learning rate : 0.000006 loss : 0.728988 +[04:40:40.728] iteration 92100 [386.86 sec]: learning rate : 0.000006 loss : 0.544013 +[04:42:06.982] iteration 92200 [473.12 sec]: learning rate : 0.000006 loss : 0.353199 +[04:43:33.290] iteration 92300 [559.43 sec]: learning rate : 0.000006 loss : 0.607868 +[04:44:59.617] iteration 92400 [645.76 sec]: learning rate : 0.000006 loss : 0.589910 +[04:46:25.895] iteration 92500 [732.03 sec]: learning rate : 0.000006 loss : 0.697827 +[04:47:52.188] iteration 92600 [818.32 sec]: learning rate : 0.000006 loss : 0.541316 +[04:49:18.443] iteration 92700 [904.58 sec]: learning rate : 0.000006 loss : 0.683175 +[04:50:44.781] iteration 92800 [990.92 sec]: learning rate : 0.000006 loss : 0.712781 +[04:52:11.112] iteration 92900 [1077.25 sec]: learning rate : 0.000006 loss : 0.421327 +[04:53:37.392] iteration 93000 [1163.53 sec]: learning rate : 0.000006 loss : 0.478027 +[04:55:03.707] iteration 93100 [1249.84 sec]: learning rate : 0.000006 loss : 0.585550 +[04:56:30.044] iteration 93200 [1336.18 sec]: learning rate : 0.000006 loss : 0.420877 +[04:57:56.322] iteration 93300 [1422.46 sec]: learning rate : 0.000006 loss : 0.469968 +[04:59:22.625] iteration 93400 [1508.76 sec]: learning rate : 0.000006 loss : 0.680499 +[05:00:48.938] iteration 93500 [1595.07 sec]: learning rate : 0.000006 loss : 0.753634 +[05:02:15.219] iteration 93600 [1681.36 sec]: learning rate : 0.000006 loss : 0.593275 +[05:03:41.575] iteration 93700 [1767.71 sec]: learning rate : 0.000006 loss : 0.470756 +[05:04:11.752] Epoch 44 Evaluation: +[05:05:03.146] average MSE: 0.05560979247093201 average PSNR: 28.17513587267052 average SSIM: 0.6013936445141139 +[05:05:59.438] iteration 93800 [56.23 sec]: learning rate : 0.000006 loss : 0.879215 +[05:07:25.761] iteration 93900 [142.55 sec]: learning rate : 0.000006 loss : 0.645807 +[05:08:52.031] iteration 94000 [228.82 sec]: learning rate : 0.000006 loss : 0.432033 +[05:10:18.289] iteration 94100 [315.08 sec]: learning rate : 0.000006 loss : 0.435162 +[05:11:44.641] iteration 94200 [401.43 sec]: learning rate : 0.000006 loss : 0.491421 +[05:13:10.984] iteration 94300 [487.78 sec]: learning rate : 0.000006 loss : 0.781246 +[05:14:37.211] iteration 94400 [574.00 sec]: learning rate : 0.000006 loss : 0.372182 +[05:16:03.497] iteration 94500 [660.29 sec]: learning rate : 0.000006 loss : 0.747162 +[05:17:29.848] iteration 94600 [746.64 sec]: learning rate : 0.000006 loss : 0.735436 +[05:18:56.118] iteration 94700 [832.91 sec]: learning rate : 0.000006 loss : 0.844715 +[05:20:22.418] iteration 94800 [919.21 sec]: learning rate : 0.000006 loss : 0.444984 +[05:21:48.691] iteration 94900 [1005.48 sec]: learning rate : 0.000006 loss : 0.401200 +[05:23:14.924] iteration 95000 [1091.72 sec]: learning rate : 0.000006 loss : 1.048200 +[05:24:41.252] iteration 95100 [1178.04 sec]: learning rate : 0.000006 loss : 0.329303 +[05:26:07.550] iteration 95200 [1264.34 sec]: learning rate : 0.000006 loss : 0.665456 +[05:27:33.913] iteration 95300 [1350.71 sec]: learning rate : 0.000006 loss : 0.626139 +[05:29:00.260] iteration 95400 [1437.05 sec]: learning 
rate : 0.000006 loss : 0.478088 +[05:30:26.576] iteration 95500 [1523.37 sec]: learning rate : 0.000006 loss : 0.411794 +[05:31:52.948] iteration 95600 [1609.74 sec]: learning rate : 0.000006 loss : 0.380993 +[05:33:19.318] iteration 95700 [1696.11 sec]: learning rate : 0.000006 loss : 0.714422 +[05:34:45.631] iteration 95800 [1782.42 sec]: learning rate : 0.000006 loss : 0.581500 +[05:35:01.138] Epoch 45 Evaluation: +[05:35:52.902] average MSE: 0.055421844124794006 average PSNR: 28.191043524249242 average SSIM: 0.6018179795707127 +[05:37:04.015] iteration 95900 [71.05 sec]: learning rate : 0.000006 loss : 0.503636 +[05:38:30.373] iteration 96000 [157.41 sec]: learning rate : 0.000006 loss : 0.518686 +[05:39:56.698] iteration 96100 [243.73 sec]: learning rate : 0.000006 loss : 0.505330 +[05:41:23.094] iteration 96200 [330.13 sec]: learning rate : 0.000006 loss : 0.436827 +[05:42:49.410] iteration 96300 [416.45 sec]: learning rate : 0.000006 loss : 0.350203 +[05:44:15.799] iteration 96400 [502.84 sec]: learning rate : 0.000006 loss : 0.614168 +[05:45:42.149] iteration 96500 [589.19 sec]: learning rate : 0.000006 loss : 0.886219 +[05:47:08.438] iteration 96600 [675.47 sec]: learning rate : 0.000006 loss : 0.869219 +[05:48:34.751] iteration 96700 [761.79 sec]: learning rate : 0.000006 loss : 0.475517 +[05:50:01.095] iteration 96800 [848.13 sec]: learning rate : 0.000006 loss : 0.537170 +[05:51:27.363] iteration 96900 [934.40 sec]: learning rate : 0.000006 loss : 0.482917 +[05:52:53.651] iteration 97000 [1020.69 sec]: learning rate : 0.000006 loss : 0.618592 +[05:54:19.968] iteration 97100 [1107.00 sec]: learning rate : 0.000006 loss : 0.936889 +[05:55:46.263] iteration 97200 [1193.30 sec]: learning rate : 0.000006 loss : 0.738020 +[05:57:12.610] iteration 97300 [1279.65 sec]: learning rate : 0.000006 loss : 0.795859 +[05:58:38.917] iteration 97400 [1365.96 sec]: learning rate : 0.000006 loss : 0.524405 +[06:00:05.239] iteration 97500 [1452.28 sec]: learning rate : 0.000006 loss : 0.680924 +[06:01:31.583] iteration 97600 [1538.62 sec]: learning rate : 0.000006 loss : 0.563274 +[06:02:57.919] iteration 97700 [1624.96 sec]: learning rate : 0.000006 loss : 0.590113 +[06:04:24.268] iteration 97800 [1711.30 sec]: learning rate : 0.000006 loss : 0.673282 +[06:05:50.649] iteration 97900 [1797.69 sec]: learning rate : 0.000006 loss : 0.664278 +[06:05:51.499] Epoch 46 Evaluation: +[06:06:42.228] average MSE: 0.055568892508745193 average PSNR: 28.179945355907552 average SSIM: 0.6018257775205622 +[06:08:07.898] iteration 98000 [85.61 sec]: learning rate : 0.000006 loss : 0.599979 +[06:09:34.264] iteration 98100 [171.97 sec]: learning rate : 0.000006 loss : 0.717532 +[06:11:00.626] iteration 98200 [258.33 sec]: learning rate : 0.000006 loss : 0.423055 +[06:12:26.895] iteration 98300 [344.60 sec]: learning rate : 0.000006 loss : 0.612318 +[06:13:53.261] iteration 98400 [430.97 sec]: learning rate : 0.000006 loss : 0.639589 +[06:15:19.682] iteration 98500 [517.39 sec]: learning rate : 0.000006 loss : 0.677395 +[06:16:45.949] iteration 98600 [603.66 sec]: learning rate : 0.000006 loss : 0.437254 +[06:18:12.343] iteration 98700 [690.05 sec]: learning rate : 0.000006 loss : 0.652701 +[06:19:38.736] iteration 98800 [776.44 sec]: learning rate : 0.000006 loss : 0.574044 +[06:21:05.045] iteration 98900 [862.75 sec]: learning rate : 0.000006 loss : 0.460565 +[06:22:31.379] iteration 99000 [949.09 sec]: learning rate : 0.000006 loss : 0.633401 +[06:23:57.717] iteration 99100 [1035.43 sec]: learning rate : 0.000006 
loss : 0.602550 +[06:25:24.012] iteration 99200 [1121.72 sec]: learning rate : 0.000006 loss : 0.644740 +[06:26:50.411] iteration 99300 [1208.12 sec]: learning rate : 0.000006 loss : 0.767302 +[06:28:16.742] iteration 99400 [1294.45 sec]: learning rate : 0.000006 loss : 0.593010 +[06:29:43.095] iteration 99500 [1380.80 sec]: learning rate : 0.000006 loss : 0.562752 +[06:31:09.450] iteration 99600 [1467.16 sec]: learning rate : 0.000006 loss : 0.467983 +[06:32:35.736] iteration 99700 [1553.44 sec]: learning rate : 0.000006 loss : 0.690522 +[06:34:02.051] iteration 99800 [1639.76 sec]: learning rate : 0.000006 loss : 0.520895 +[06:35:28.434] iteration 99900 [1726.14 sec]: learning rate : 0.000006 loss : 0.476899 +[06:36:40.944] Epoch 47 Evaluation: +[06:37:30.107] average MSE: 0.055474668741226196 average PSNR: 28.188499955033183 average SSIM: 0.6018062841858316 +[06:37:44.145] iteration 100000 [13.98 sec]: learning rate : 0.000002 loss : 0.724326 +[06:37:44.308] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_100000.pth +[06:37:45.143] Epoch 48 Evaluation: +[06:38:36.210] average MSE: 0.05550547316670418 average PSNR: 28.18536053775759 average SSIM: 0.6015941183409935 +[06:38:36.477] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_100000.pth +===> Evaluate Metric <=== +Direct Results +------------------------------------ +NMSE: 4.6822 ± 0.9425 +PSNR: 28.0308 ± 1.5644 +SSIM: 0.5819 ± 0.0607 +------------------------------------ +===> Evaluate Metric <=== +Results +------------------------------------ +NMSE: 4.6787 ± 0.9861 +PSNR: 28.0399 ± 1.5972 +SSIM: 0.5896 ± 0.0585 +------------------------------------ \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_/log/events.out.tfevents.1752411517.GCRSANDBOX133 b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_/log/events.out.tfevents.1752411517.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..dfb7075ac364e36834c93ec20e9dcddb4278b347 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_/log/events.out.tfevents.1752411517.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:522c945436b418c31cf7fe0d952caf514c109e102cf8986c976fe13585434f73 +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..1a05db3b0a9a0ca9856fc70aef731143272ad7b1 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0286e506df8ba1f61b0bbb0c967b2cff38c7285e28f5005ae55691b5606ac4d6 +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..e5b910a0e86a4164f91e718fb16a17438268316e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/log.txt @@ -0,0 +1,1105 @@ +[23:41:33.087] 
Namespace(root_path='/home/v-qichen3/MRI_recon/data/fastmri', MRIDOWN='8X', low_field_SNR=0, phase='train', gpu='3', exp='FSMNet_fastmri_8x', max_iterations=100000, batch_size=4, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, snapshot_path='None', rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='equispaced', CENTER_FRACTIONS=[0.04], ACCELERATIONS=[8], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[23:43:02.800] iteration 100 [87.06 sec]: learning rate : 0.000100 loss : 0.686934 +[23:44:28.952] iteration 200 [173.20 sec]: learning rate : 0.000100 loss : 0.806370 +[23:45:55.210] iteration 300 [259.45 sec]: learning rate : 0.000100 loss : 0.933832 +[23:47:21.496] iteration 400 [345.74 sec]: learning rate : 0.000100 loss : 0.744205 +[23:48:47.769] iteration 500 [432.01 sec]: learning rate : 0.000100 loss : 0.411553 +[23:50:14.117] iteration 600 [518.36 sec]: learning rate : 0.000100 loss : 0.693185 +[23:51:40.511] iteration 700 [604.75 sec]: learning rate : 0.000100 loss : 0.521932 +[23:53:06.831] iteration 800 [691.07 sec]: learning rate : 0.000100 loss : 0.887903 +[23:54:33.241] iteration 900 [777.48 sec]: learning rate : 0.000100 loss : 0.585629 +[23:55:59.717] iteration 1000 [863.96 sec]: learning rate : 0.000100 loss : 0.487472 +[23:57:26.105] iteration 1100 [950.35 sec]: learning rate : 0.000100 loss : 0.971876 +[23:58:52.507] iteration 1200 [1036.75 sec]: learning rate : 0.000100 loss : 0.956802 +[00:00:18.927] iteration 1300 [1123.17 sec]: learning rate : 0.000100 loss : 0.759174 +[00:01:45.431] iteration 1400 [1209.67 sec]: learning rate : 0.000100 loss : 0.921390 +[00:03:11.861] iteration 1500 [1296.10 sec]: learning rate : 0.000100 loss : 1.408719 +[00:04:38.290] iteration 1600 [1382.53 sec]: learning rate : 0.000100 loss : 0.993416 +[00:06:04.780] iteration 1700 [1469.02 sec]: learning rate : 0.000100 loss : 0.824427 +[00:07:31.220] iteration 1800 [1555.46 sec]: learning rate : 0.000100 loss : 0.917329 +[00:08:57.703] iteration 1900 [1641.95 sec]: learning rate : 0.000100 loss : 1.050593 +[00:10:24.255] iteration 2000 [1728.50 sec]: learning rate : 0.000100 loss : 0.585130 +[00:11:36.041] Epoch 0 Evaluation: +[00:12:26.913] average MSE: 0.06771890074014664 average PSNR: 27.030683418933837 average SSIM: 0.5511461229033533 +[00:12:41.875] iteration 2100 [14.90 sec]: learning rate : 0.000100 loss : 0.503566 +[00:14:08.449] iteration 2200 [101.47 sec]: learning rate : 0.000100 loss : 0.832337 +[00:15:35.002] iteration 2300 [188.02 sec]: learning rate : 0.000100 loss : 0.830108 +[00:17:01.500] iteration 2400 [274.52 sec]: learning rate : 0.000100 loss : 0.903948 +[00:18:28.063] iteration 2500 [361.09 sec]: learning rate : 0.000100 loss : 0.883680 +[00:19:54.579] iteration 2600 [447.60 sec]: learning rate : 0.000100 loss : 0.901029 +[00:21:21.178] iteration 2700 [534.20 sec]: learning rate : 0.000100 loss : 1.106143 +[00:22:47.739] iteration 2800 [620.76 sec]: learning rate : 0.000100 loss : 0.912510 +[00:24:14.242] iteration 2900 [707.27 sec]: learning rate : 0.000100 loss : 0.768687 +[00:25:40.819] iteration 3000 [793.84 sec]: learning rate : 0.000100 loss : 0.600804 
+[00:27:07.367] iteration 3100 [880.39 sec]: learning rate : 0.000100 loss : 0.851269 +[00:28:33.853] iteration 3200 [966.88 sec]: learning rate : 0.000100 loss : 0.843562 +[00:30:00.414] iteration 3300 [1053.44 sec]: learning rate : 0.000100 loss : 0.575664 +[00:31:26.949] iteration 3400 [1139.97 sec]: learning rate : 0.000100 loss : 0.800741 +[00:32:53.445] iteration 3500 [1226.47 sec]: learning rate : 0.000100 loss : 0.603604 +[00:34:19.985] iteration 3600 [1313.01 sec]: learning rate : 0.000100 loss : 0.713947 +[00:35:46.556] iteration 3700 [1399.64 sec]: learning rate : 0.000100 loss : 0.732115 +[00:37:13.026] iteration 3800 [1486.05 sec]: learning rate : 0.000100 loss : 0.737049 +[00:38:39.543] iteration 3900 [1572.57 sec]: learning rate : 0.000100 loss : 0.642374 +[00:40:06.055] iteration 4000 [1659.08 sec]: learning rate : 0.000100 loss : 0.630302 +[00:41:32.457] iteration 4100 [1745.48 sec]: learning rate : 0.000100 loss : 1.070170 +[00:42:29.569] Epoch 1 Evaluation: +[00:43:20.318] average MSE: 0.06740959733724594 average PSNR: 27.05194164646122 average SSIM: 0.5629047112698746 +[00:43:49.954] iteration 4200 [29.57 sec]: learning rate : 0.000100 loss : 0.662025 +[00:45:16.365] iteration 4300 [115.98 sec]: learning rate : 0.000100 loss : 0.685902 +[00:46:42.912] iteration 4400 [202.53 sec]: learning rate : 0.000100 loss : 0.584809 +[00:48:09.379] iteration 4500 [289.00 sec]: learning rate : 0.000100 loss : 0.579475 +[00:49:35.778] iteration 4600 [375.40 sec]: learning rate : 0.000100 loss : 0.502147 +[00:51:02.271] iteration 4700 [461.89 sec]: learning rate : 0.000100 loss : 1.291847 +[00:52:28.729] iteration 4800 [548.35 sec]: learning rate : 0.000100 loss : 0.648725 +[00:53:55.119] iteration 4900 [634.74 sec]: learning rate : 0.000100 loss : 0.612523 +[00:55:21.583] iteration 5000 [721.20 sec]: learning rate : 0.000100 loss : 0.587966 +[00:56:48.022] iteration 5100 [807.64 sec]: learning rate : 0.000100 loss : 0.620993 +[00:58:14.384] iteration 5200 [894.00 sec]: learning rate : 0.000100 loss : 0.515293 +[00:59:40.809] iteration 5300 [980.43 sec]: learning rate : 0.000100 loss : 0.674140 +[01:01:07.242] iteration 5400 [1066.86 sec]: learning rate : 0.000100 loss : 0.806371 +[01:02:33.609] iteration 5500 [1153.23 sec]: learning rate : 0.000100 loss : 0.697128 +[01:04:00.048] iteration 5600 [1239.67 sec]: learning rate : 0.000100 loss : 0.554092 +[01:05:26.409] iteration 5700 [1326.03 sec]: learning rate : 0.000100 loss : 0.738021 +[01:06:52.834] iteration 5800 [1412.45 sec]: learning rate : 0.000100 loss : 0.930301 +[01:08:19.216] iteration 5900 [1498.83 sec]: learning rate : 0.000100 loss : 0.535885 +[01:09:45.527] iteration 6000 [1585.15 sec]: learning rate : 0.000100 loss : 0.560490 +[01:11:11.891] iteration 6100 [1671.51 sec]: learning rate : 0.000100 loss : 0.570412 +[01:12:38.312] iteration 6200 [1757.93 sec]: learning rate : 0.000100 loss : 0.841589 +[01:13:20.579] Epoch 2 Evaluation: +[01:14:10.132] average MSE: 0.06212643161416054 average PSNR: 27.47101323736568 average SSIM: 0.5752163053925181 +[01:14:54.382] iteration 6300 [44.19 sec]: learning rate : 0.000100 loss : 0.493597 +[01:16:20.716] iteration 6400 [130.52 sec]: learning rate : 0.000100 loss : 0.800454 +[01:17:47.046] iteration 6500 [216.85 sec]: learning rate : 0.000100 loss : 0.330391 +[01:19:13.341] iteration 6600 [303.15 sec]: learning rate : 0.000100 loss : 0.730794 +[01:20:39.709] iteration 6700 [389.51 sec]: learning rate : 0.000100 loss : 0.671578 +[01:22:06.080] iteration 6800 [475.89 sec]: learning 
rate : 0.000100 loss : 0.553221 +[01:23:32.376] iteration 6900 [562.18 sec]: learning rate : 0.000100 loss : 0.554918 +[01:24:58.695] iteration 7000 [648.50 sec]: learning rate : 0.000100 loss : 0.752271 +[01:26:24.999] iteration 7100 [734.80 sec]: learning rate : 0.000100 loss : 0.785115 +[01:27:51.365] iteration 7200 [821.17 sec]: learning rate : 0.000100 loss : 0.508208 +[01:29:17.691] iteration 7300 [907.50 sec]: learning rate : 0.000100 loss : 0.668182 +[01:30:43.942] iteration 7400 [993.75 sec]: learning rate : 0.000100 loss : 0.396943 +[01:32:10.261] iteration 7500 [1080.07 sec]: learning rate : 0.000100 loss : 0.728511 +[01:33:36.602] iteration 7600 [1166.41 sec]: learning rate : 0.000100 loss : 0.612895 +[01:35:02.930] iteration 7700 [1252.75 sec]: learning rate : 0.000100 loss : 0.474731 +[01:36:29.235] iteration 7800 [1339.04 sec]: learning rate : 0.000100 loss : 0.853154 +[01:37:55.555] iteration 7900 [1425.36 sec]: learning rate : 0.000100 loss : 0.670141 +[01:39:21.828] iteration 8000 [1511.63 sec]: learning rate : 0.000100 loss : 0.549733 +[01:40:48.142] iteration 8100 [1597.95 sec]: learning rate : 0.000100 loss : 0.552433 +[01:42:14.475] iteration 8200 [1684.28 sec]: learning rate : 0.000100 loss : 0.778704 +[01:43:40.719] iteration 8300 [1770.52 sec]: learning rate : 0.000100 loss : 0.606724 +[01:44:08.376] Epoch 3 Evaluation: +[01:44:59.630] average MSE: 0.061993557959795 average PSNR: 27.507615497157058 average SSIM: 0.5778273946850346 +[01:45:58.538] iteration 8400 [58.85 sec]: learning rate : 0.000100 loss : 0.812536 +[01:47:24.888] iteration 8500 [145.19 sec]: learning rate : 0.000100 loss : 0.364732 +[01:48:51.131] iteration 8600 [231.44 sec]: learning rate : 0.000100 loss : 0.722282 +[01:50:17.424] iteration 8700 [317.73 sec]: learning rate : 0.000100 loss : 0.667588 +[01:51:43.719] iteration 8800 [404.03 sec]: learning rate : 0.000100 loss : 0.652864 +[01:53:09.987] iteration 8900 [490.29 sec]: learning rate : 0.000100 loss : 0.622632 +[01:54:36.292] iteration 9000 [576.60 sec]: learning rate : 0.000100 loss : 0.917782 +[01:56:02.550] iteration 9100 [662.86 sec]: learning rate : 0.000100 loss : 0.533774 +[01:57:28.880] iteration 9200 [749.19 sec]: learning rate : 0.000100 loss : 1.010461 +[01:58:55.174] iteration 9300 [835.48 sec]: learning rate : 0.000100 loss : 0.495178 +[02:00:21.386] iteration 9400 [921.69 sec]: learning rate : 0.000100 loss : 0.677087 +[02:01:47.675] iteration 9500 [1007.98 sec]: learning rate : 0.000100 loss : 0.385740 +[02:03:13.988] iteration 9600 [1094.29 sec]: learning rate : 0.000100 loss : 0.807401 +[02:04:40.207] iteration 9700 [1180.51 sec]: learning rate : 0.000100 loss : 0.625983 +[02:06:06.481] iteration 9800 [1266.79 sec]: learning rate : 0.000100 loss : 0.507432 +[02:07:32.793] iteration 9900 [1353.10 sec]: learning rate : 0.000100 loss : 0.691564 +[02:08:59.040] iteration 10000 [1439.35 sec]: learning rate : 0.000100 loss : 0.808972 +[02:10:25.363] iteration 10100 [1525.67 sec]: learning rate : 0.000100 loss : 0.541645 +[02:11:51.704] iteration 10200 [1612.01 sec]: learning rate : 0.000100 loss : 0.638650 +[02:13:17.966] iteration 10300 [1698.27 sec]: learning rate : 0.000100 loss : 0.962796 +[02:14:44.212] iteration 10400 [1784.52 sec]: learning rate : 0.000100 loss : 0.926998 +[02:14:57.134] Epoch 4 Evaluation: +[02:15:46.539] average MSE: 0.061654653400182724 average PSNR: 27.591149891975373 average SSIM: 0.5856478687763134 +[02:17:00.093] iteration 10500 [73.49 sec]: learning rate : 0.000100 loss : 0.533822 +[02:18:26.440] 
iteration 10600 [159.84 sec]: learning rate : 0.000100 loss : 0.628890 +[02:19:52.721] iteration 10700 [246.12 sec]: learning rate : 0.000100 loss : 0.622948 +[02:21:18.980] iteration 10800 [332.38 sec]: learning rate : 0.000100 loss : 0.849071 +[02:22:45.331] iteration 10900 [418.73 sec]: learning rate : 0.000100 loss : 0.594749 +[02:24:11.639] iteration 11000 [505.04 sec]: learning rate : 0.000100 loss : 0.681287 +[02:25:37.924] iteration 11100 [591.32 sec]: learning rate : 0.000100 loss : 0.778355 +[02:27:04.253] iteration 11200 [677.65 sec]: learning rate : 0.000100 loss : 0.547874 +[02:28:30.596] iteration 11300 [764.00 sec]: learning rate : 0.000100 loss : 0.428158 +[02:29:56.902] iteration 11400 [850.30 sec]: learning rate : 0.000100 loss : 0.702284 +[02:31:23.246] iteration 11500 [936.64 sec]: learning rate : 0.000100 loss : 0.668227 +[02:32:49.559] iteration 11600 [1022.96 sec]: learning rate : 0.000100 loss : 1.052269 +[02:34:15.901] iteration 11700 [1109.30 sec]: learning rate : 0.000100 loss : 0.515069 +[02:35:42.247] iteration 11800 [1195.65 sec]: learning rate : 0.000100 loss : 0.644975 +[02:37:08.583] iteration 11900 [1281.98 sec]: learning rate : 0.000100 loss : 0.567772 +[02:38:35.005] iteration 12000 [1368.40 sec]: learning rate : 0.000100 loss : 0.647389 +[02:40:01.363] iteration 12100 [1454.76 sec]: learning rate : 0.000100 loss : 0.482086 +[02:41:27.670] iteration 12200 [1541.07 sec]: learning rate : 0.000100 loss : 0.501617 +[02:42:54.041] iteration 12300 [1627.44 sec]: learning rate : 0.000100 loss : 0.542000 +[02:44:20.386] iteration 12400 [1713.78 sec]: learning rate : 0.000100 loss : 0.674330 +[02:45:44.942] Epoch 5 Evaluation: +[02:46:36.126] average MSE: 0.060298506170511246 average PSNR: 27.677833437964182 average SSIM: 0.5847177002896037 +[02:46:38.117] iteration 12500 [1.93 sec]: learning rate : 0.000100 loss : 0.325365 +[02:48:04.523] iteration 12600 [88.34 sec]: learning rate : 0.000100 loss : 0.649197 +[02:49:30.884] iteration 12700 [174.69 sec]: learning rate : 0.000100 loss : 0.656535 +[02:50:57.221] iteration 12800 [261.03 sec]: learning rate : 0.000100 loss : 0.495281 +[02:52:23.659] iteration 12900 [347.47 sec]: learning rate : 0.000100 loss : 0.619288 +[02:53:50.042] iteration 13000 [433.85 sec]: learning rate : 0.000100 loss : 0.425093 +[02:55:16.459] iteration 13100 [520.27 sec]: learning rate : 0.000100 loss : 1.407353 +[02:56:42.901] iteration 13200 [606.71 sec]: learning rate : 0.000100 loss : 0.611240 +[02:58:09.250] iteration 13300 [693.06 sec]: learning rate : 0.000100 loss : 0.598244 +[02:59:35.654] iteration 13400 [779.46 sec]: learning rate : 0.000100 loss : 0.877338 +[03:01:02.063] iteration 13500 [865.87 sec]: learning rate : 0.000100 loss : 0.435730 +[03:02:28.436] iteration 13600 [952.25 sec]: learning rate : 0.000100 loss : 0.694397 +[03:03:54.869] iteration 13700 [1038.68 sec]: learning rate : 0.000100 loss : 0.490954 +[03:05:21.328] iteration 13800 [1125.14 sec]: learning rate : 0.000100 loss : 0.624239 +[03:06:47.704] iteration 13900 [1211.51 sec]: learning rate : 0.000100 loss : 0.622853 +[03:08:14.168] iteration 14000 [1297.98 sec]: learning rate : 0.000100 loss : 0.576034 +[03:09:40.592] iteration 14100 [1384.40 sec]: learning rate : 0.000100 loss : 0.826589 +[03:11:07.011] iteration 14200 [1470.82 sec]: learning rate : 0.000100 loss : 0.730345 +[03:12:33.482] iteration 14300 [1557.29 sec]: learning rate : 0.000100 loss : 0.687232 +[03:13:59.965] iteration 14400 [1643.77 sec]: learning rate : 0.000100 loss : 0.689310 
+[03:15:26.391] iteration 14500 [1730.20 sec]: learning rate : 0.000100 loss : 0.478177 +[03:16:36.482] Epoch 6 Evaluation: +[03:17:26.000] average MSE: 0.05908812955021858 average PSNR: 27.790304223645663 average SSIM: 0.5894811533523657 +[03:17:42.704] iteration 14600 [16.63 sec]: learning rate : 0.000100 loss : 0.580761 +[03:19:09.256] iteration 14700 [103.18 sec]: learning rate : 0.000100 loss : 0.708817 +[03:20:35.692] iteration 14800 [189.62 sec]: learning rate : 0.000100 loss : 0.603301 +[03:22:02.229] iteration 14900 [276.16 sec]: learning rate : 0.000100 loss : 0.573178 +[03:23:28.671] iteration 15000 [362.60 sec]: learning rate : 0.000100 loss : 0.679997 +[03:24:55.143] iteration 15100 [449.07 sec]: learning rate : 0.000100 loss : 0.674477 +[03:26:21.629] iteration 15200 [535.56 sec]: learning rate : 0.000100 loss : 0.553699 +[03:27:48.059] iteration 15300 [621.99 sec]: learning rate : 0.000100 loss : 0.633802 +[03:29:14.558] iteration 15400 [708.49 sec]: learning rate : 0.000100 loss : 0.840366 +[03:30:41.074] iteration 15500 [795.00 sec]: learning rate : 0.000100 loss : 0.721321 +[03:32:07.500] iteration 15600 [881.43 sec]: learning rate : 0.000100 loss : 0.311877 +[03:33:33.961] iteration 15700 [967.89 sec]: learning rate : 0.000100 loss : 0.591676 +[03:35:00.390] iteration 15800 [1054.32 sec]: learning rate : 0.000100 loss : 0.730590 +[03:36:26.882] iteration 15900 [1140.81 sec]: learning rate : 0.000100 loss : 0.638604 +[03:37:53.355] iteration 16000 [1227.29 sec]: learning rate : 0.000100 loss : 0.359131 +[03:39:19.783] iteration 16100 [1313.71 sec]: learning rate : 0.000100 loss : 0.464340 +[03:40:46.262] iteration 16200 [1400.19 sec]: learning rate : 0.000100 loss : 0.598050 +[03:42:12.741] iteration 16300 [1486.67 sec]: learning rate : 0.000100 loss : 0.640183 +[03:43:39.160] iteration 16400 [1573.09 sec]: learning rate : 0.000100 loss : 0.419135 +[03:45:05.654] iteration 16500 [1659.59 sec]: learning rate : 0.000100 loss : 0.807252 +[03:46:32.140] iteration 16600 [1746.07 sec]: learning rate : 0.000100 loss : 0.860258 +[03:47:27.436] Epoch 7 Evaluation: +[03:48:17.256] average MSE: 0.05930082127451897 average PSNR: 27.798578401711612 average SSIM: 0.5917902013420258 +[03:48:48.597] iteration 16700 [31.28 sec]: learning rate : 0.000100 loss : 0.897112 +[03:50:15.097] iteration 16800 [117.78 sec]: learning rate : 0.000100 loss : 0.771574 +[03:51:41.560] iteration 16900 [204.24 sec]: learning rate : 0.000100 loss : 0.429683 +[03:53:07.979] iteration 17000 [290.66 sec]: learning rate : 0.000100 loss : 0.786671 +[03:54:34.445] iteration 17100 [377.13 sec]: learning rate : 0.000100 loss : 0.365274 +[03:56:00.842] iteration 17200 [463.52 sec]: learning rate : 0.000100 loss : 0.742959 +[03:57:27.341] iteration 17300 [550.02 sec]: learning rate : 0.000100 loss : 0.546028 +[03:58:53.766] iteration 17400 [636.45 sec]: learning rate : 0.000100 loss : 0.384512 +[04:00:20.185] iteration 17500 [722.86 sec]: learning rate : 0.000100 loss : 0.547299 +[04:01:46.633] iteration 17600 [809.31 sec]: learning rate : 0.000100 loss : 0.761797 +[04:03:13.102] iteration 17700 [895.78 sec]: learning rate : 0.000100 loss : 0.490339 +[04:04:39.508] iteration 17800 [982.19 sec]: learning rate : 0.000100 loss : 0.425512 +[04:06:05.962] iteration 17900 [1068.64 sec]: learning rate : 0.000100 loss : 0.373686 +[04:07:32.427] iteration 18000 [1155.11 sec]: learning rate : 0.000100 loss : 1.113707 +[04:08:58.839] iteration 18100 [1241.52 sec]: learning rate : 0.000100 loss : 0.471286 +[04:10:25.257] 
iteration 18200 [1327.94 sec]: learning rate : 0.000100 loss : 0.476366 +[04:11:51.666] iteration 18300 [1414.35 sec]: learning rate : 0.000100 loss : 0.703005 +[04:13:18.130] iteration 18400 [1500.81 sec]: learning rate : 0.000100 loss : 0.622033 +[04:14:44.645] iteration 18500 [1587.33 sec]: learning rate : 0.000100 loss : 0.921573 +[04:16:11.032] iteration 18600 [1673.71 sec]: learning rate : 0.000100 loss : 0.556849 +[04:17:37.465] iteration 18700 [1760.15 sec]: learning rate : 0.000100 loss : 0.743781 +[04:18:18.042] Epoch 8 Evaluation: +[04:19:07.865] average MSE: 0.058756958693265915 average PSNR: 27.84747102193612 average SSIM: 0.5941692019046706 +[04:19:53.997] iteration 18800 [46.07 sec]: learning rate : 0.000100 loss : 0.966464 +[04:21:20.357] iteration 18900 [132.43 sec]: learning rate : 0.000100 loss : 0.892108 +[04:22:46.791] iteration 19000 [218.86 sec]: learning rate : 0.000100 loss : 0.561191 +[04:24:13.261] iteration 19100 [305.33 sec]: learning rate : 0.000100 loss : 0.840956 +[04:25:39.630] iteration 19200 [391.70 sec]: learning rate : 0.000100 loss : 0.756529 +[04:27:06.040] iteration 19300 [478.11 sec]: learning rate : 0.000100 loss : 0.429257 +[04:28:32.485] iteration 19400 [564.56 sec]: learning rate : 0.000100 loss : 0.609197 +[04:29:58.852] iteration 19500 [650.92 sec]: learning rate : 0.000100 loss : 0.670285 +[04:31:25.292] iteration 19600 [737.37 sec]: learning rate : 0.000100 loss : 0.758609 +[04:32:51.717] iteration 19700 [823.79 sec]: learning rate : 0.000100 loss : 0.539992 +[04:34:18.045] iteration 19800 [910.12 sec]: learning rate : 0.000100 loss : 0.782600 +[04:35:44.491] iteration 19900 [996.56 sec]: learning rate : 0.000100 loss : 0.532430 +[04:37:10.880] iteration 20000 [1082.95 sec]: learning rate : 0.000025 loss : 0.498853 +[04:37:11.158] save model to model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/iter_20000.pth +[04:38:37.523] iteration 20100 [1169.60 sec]: learning rate : 0.000050 loss : 0.819929 +[04:40:04.009] iteration 20200 [1256.08 sec]: learning rate : 0.000050 loss : 0.395848 +[04:41:30.362] iteration 20300 [1342.43 sec]: learning rate : 0.000050 loss : 0.835883 +[04:42:56.824] iteration 20400 [1428.90 sec]: learning rate : 0.000050 loss : 0.456192 +[04:44:23.244] iteration 20500 [1515.32 sec]: learning rate : 0.000050 loss : 0.674407 +[04:45:49.621] iteration 20600 [1601.69 sec]: learning rate : 0.000050 loss : 0.794384 +[04:47:16.048] iteration 20700 [1688.12 sec]: learning rate : 0.000050 loss : 0.431320 +[04:48:42.447] iteration 20800 [1774.52 sec]: learning rate : 0.000050 loss : 0.679675 +[04:49:08.329] Epoch 9 Evaluation: +[04:50:00.134] average MSE: 0.057724107056856155 average PSNR: 27.941789023208297 average SSIM: 0.5945150346032858 +[04:51:00.834] iteration 20900 [60.64 sec]: learning rate : 0.000050 loss : 0.446814 +[04:52:27.301] iteration 21000 [147.10 sec]: learning rate : 0.000050 loss : 0.594957 +[04:53:53.710] iteration 21100 [233.51 sec]: learning rate : 0.000050 loss : 0.642850 +[04:55:20.046] iteration 21200 [319.85 sec]: learning rate : 0.000050 loss : 0.470717 +[04:56:46.479] iteration 21300 [406.28 sec]: learning rate : 0.000050 loss : 0.494588 +[04:58:12.821] iteration 21400 [492.62 sec]: learning rate : 0.000050 loss : 0.897750 +[04:59:39.230] iteration 21500 [579.03 sec]: learning rate : 0.000050 loss : 0.708780 +[05:01:05.640] iteration 21600 [665.44 sec]: learning rate : 0.000050 loss : 0.495722 +[05:02:31.985] iteration 21700 [751.79 sec]: learning rate : 0.000050 loss : 0.353979 +[05:03:58.385] 
iteration 21800 [838.19 sec]: learning rate : 0.000050 loss : 0.856154 +[05:05:24.788] iteration 21900 [924.59 sec]: learning rate : 0.000050 loss : 0.476958 +[05:06:51.119] iteration 22000 [1010.92 sec]: learning rate : 0.000050 loss : 1.048299 +[05:08:17.512] iteration 22100 [1097.31 sec]: learning rate : 0.000050 loss : 0.515844 +[05:09:43.923] iteration 22200 [1183.73 sec]: learning rate : 0.000050 loss : 0.550157 +[05:11:10.256] iteration 22300 [1270.06 sec]: learning rate : 0.000050 loss : 0.548238 +[05:12:36.676] iteration 22400 [1356.48 sec]: learning rate : 0.000050 loss : 0.675603 +[05:14:03.054] iteration 22500 [1442.86 sec]: learning rate : 0.000050 loss : 0.536049 +[05:15:29.377] iteration 22600 [1529.18 sec]: learning rate : 0.000050 loss : 0.453479 +[05:16:55.780] iteration 22700 [1615.58 sec]: learning rate : 0.000050 loss : 0.374681 +[05:18:22.197] iteration 22800 [1702.00 sec]: learning rate : 0.000050 loss : 0.623042 +[05:19:48.533] iteration 22900 [1788.34 sec]: learning rate : 0.000050 loss : 0.512962 +[05:19:59.740] Epoch 10 Evaluation: +[05:20:49.071] average MSE: 0.057691991329193115 average PSNR: 27.970571007483567 average SSIM: 0.5973500486422204 +[05:22:04.519] iteration 23000 [75.38 sec]: learning rate : 0.000050 loss : 0.614063 +[05:23:30.903] iteration 23100 [161.77 sec]: learning rate : 0.000050 loss : 0.480461 +[05:24:57.247] iteration 23200 [248.11 sec]: learning rate : 0.000050 loss : 0.716706 +[05:26:23.608] iteration 23300 [334.47 sec]: learning rate : 0.000050 loss : 0.782476 +[05:27:49.946] iteration 23400 [420.81 sec]: learning rate : 0.000050 loss : 0.592099 +[05:29:16.319] iteration 23500 [507.18 sec]: learning rate : 0.000050 loss : 0.686562 +[05:30:42.744] iteration 23600 [593.61 sec]: learning rate : 0.000050 loss : 0.416191 +[05:32:09.079] iteration 23700 [679.95 sec]: learning rate : 0.000050 loss : 0.748005 +[05:33:35.511] iteration 23800 [766.38 sec]: learning rate : 0.000050 loss : 0.498704 +[05:35:01.938] iteration 23900 [852.81 sec]: learning rate : 0.000050 loss : 0.683183 +[05:36:28.275] iteration 24000 [939.14 sec]: learning rate : 0.000050 loss : 0.758576 +[05:37:54.691] iteration 24100 [1025.56 sec]: learning rate : 0.000050 loss : 0.717717 +[05:39:21.040] iteration 24200 [1111.91 sec]: learning rate : 0.000050 loss : 0.947402 +[05:40:47.432] iteration 24300 [1198.30 sec]: learning rate : 0.000050 loss : 0.614142 +[05:42:13.850] iteration 24400 [1284.72 sec]: learning rate : 0.000050 loss : 0.461753 +[05:43:40.191] iteration 24500 [1371.06 sec]: learning rate : 0.000050 loss : 0.437932 +[05:45:06.583] iteration 24600 [1457.45 sec]: learning rate : 0.000050 loss : 0.685200 +[05:46:32.950] iteration 24700 [1543.82 sec]: learning rate : 0.000050 loss : 0.558957 +[05:47:59.363] iteration 24800 [1630.23 sec]: learning rate : 0.000050 loss : 0.508676 +[05:49:25.736] iteration 24900 [1716.60 sec]: learning rate : 0.000050 loss : 0.574061 +[05:50:48.602] Epoch 11 Evaluation: +[05:51:38.329] average MSE: 0.057394057512283325 average PSNR: 28.007929867087682 average SSIM: 0.5984068539967969 +[05:51:42.027] iteration 25000 [3.63 sec]: learning rate : 0.000050 loss : 0.539767 +[05:53:08.433] iteration 25100 [90.04 sec]: learning rate : 0.000050 loss : 0.338026 +[05:54:34.810] iteration 25200 [176.42 sec]: learning rate : 0.000050 loss : 0.445132 +[05:56:01.152] iteration 25300 [262.76 sec]: learning rate : 0.000050 loss : 0.717543 +[05:57:27.556] iteration 25400 [349.16 sec]: learning rate : 0.000050 loss : 0.717867 +[05:58:53.903] iteration 
25500 [435.51 sec]: learning rate : 0.000050 loss : 0.520804 +[06:00:20.336] iteration 25600 [521.94 sec]: learning rate : 0.000050 loss : 1.125798 +[06:01:46.799] iteration 25700 [608.40 sec]: learning rate : 0.000050 loss : 0.483057 +[06:03:13.137] iteration 25800 [694.74 sec]: learning rate : 0.000050 loss : 0.599242 +[06:04:39.548] iteration 25900 [781.15 sec]: learning rate : 0.000050 loss : 0.712338 +[06:06:05.948] iteration 26000 [867.55 sec]: learning rate : 0.000050 loss : 0.658761 +[06:07:32.264] iteration 26100 [953.87 sec]: learning rate : 0.000050 loss : 0.654367 +[06:08:58.629] iteration 26200 [1040.23 sec]: learning rate : 0.000050 loss : 0.550846 +[06:10:25.029] iteration 26300 [1126.63 sec]: learning rate : 0.000050 loss : 0.524504 +[06:11:51.391] iteration 26400 [1213.00 sec]: learning rate : 0.000050 loss : 0.664554 +[06:13:17.843] iteration 26500 [1299.45 sec]: learning rate : 0.000050 loss : 0.504356 +[06:14:44.240] iteration 26600 [1385.85 sec]: learning rate : 0.000050 loss : 0.869313 +[06:16:10.644] iteration 26700 [1472.25 sec]: learning rate : 0.000050 loss : 0.896091 +[06:17:37.078] iteration 26800 [1558.68 sec]: learning rate : 0.000050 loss : 0.785814 +[06:19:03.431] iteration 26900 [1645.04 sec]: learning rate : 0.000050 loss : 0.814137 +[06:20:29.844] iteration 27000 [1731.45 sec]: learning rate : 0.000050 loss : 0.418923 +[06:21:38.096] Epoch 12 Evaluation: +[06:22:27.694] average MSE: 0.057263538241386414 average PSNR: 28.015108207730037 average SSIM: 0.5984137030584858 +[06:22:46.094] iteration 27100 [18.36 sec]: learning rate : 0.000050 loss : 0.771133 +[06:24:12.449] iteration 27200 [104.69 sec]: learning rate : 0.000050 loss : 0.410288 +[06:25:38.957] iteration 27300 [191.20 sec]: learning rate : 0.000050 loss : 0.618368 +[06:27:05.411] iteration 27400 [277.65 sec]: learning rate : 0.000050 loss : 0.515821 +[06:28:31.768] iteration 27500 [364.01 sec]: learning rate : 0.000050 loss : 0.464747 +[06:29:58.217] iteration 27600 [450.46 sec]: learning rate : 0.000050 loss : 0.510097 +[06:31:24.613] iteration 27700 [536.86 sec]: learning rate : 0.000050 loss : 0.820418 +[06:32:51.007] iteration 27800 [623.25 sec]: learning rate : 0.000050 loss : 0.558211 +[06:34:17.476] iteration 27900 [709.72 sec]: learning rate : 0.000050 loss : 0.610764 +[06:35:43.978] iteration 28000 [796.22 sec]: learning rate : 0.000050 loss : 0.642797 +[06:37:10.378] iteration 28100 [882.62 sec]: learning rate : 0.000050 loss : 0.664049 +[06:38:36.825] iteration 28200 [969.07 sec]: learning rate : 0.000050 loss : 0.557198 +[06:40:03.283] iteration 28300 [1055.53 sec]: learning rate : 0.000050 loss : 0.865703 +[06:41:29.678] iteration 28400 [1141.92 sec]: learning rate : 0.000050 loss : 0.783866 +[06:42:56.102] iteration 28500 [1228.35 sec]: learning rate : 0.000050 loss : 0.271763 +[06:44:22.567] iteration 28600 [1314.81 sec]: learning rate : 0.000050 loss : 0.440876 +[06:45:48.960] iteration 28700 [1401.20 sec]: learning rate : 0.000050 loss : 0.553919 +[06:47:15.361] iteration 28800 [1487.60 sec]: learning rate : 0.000050 loss : 0.593385 +[06:48:41.818] iteration 28900 [1574.06 sec]: learning rate : 0.000050 loss : 0.815849 +[06:50:08.200] iteration 29000 [1660.44 sec]: learning rate : 0.000050 loss : 0.837029 +[06:51:34.613] iteration 29100 [1746.86 sec]: learning rate : 0.000050 loss : 0.494794 +[06:52:28.175] Epoch 13 Evaluation: +[06:53:19.969] average MSE: 0.05717639997601509 average PSNR: 28.03786422808609 average SSIM: 0.5987058747973671 +[06:53:53.148] iteration 29200 [33.12 
sec]: learning rate : 0.000050 loss : 0.490666 +[06:55:19.535] iteration 29300 [119.50 sec]: learning rate : 0.000050 loss : 0.586870 +[06:56:45.978] iteration 29400 [205.95 sec]: learning rate : 0.000050 loss : 0.675699 +[06:58:12.434] iteration 29500 [292.40 sec]: learning rate : 0.000050 loss : 0.460325 +[06:59:38.823] iteration 29600 [378.79 sec]: learning rate : 0.000050 loss : 0.934071 +[07:01:05.302] iteration 29700 [465.27 sec]: learning rate : 0.000050 loss : 0.753681 +[07:02:31.709] iteration 29800 [551.68 sec]: learning rate : 0.000050 loss : 0.460279 +[07:03:58.157] iteration 29900 [638.12 sec]: learning rate : 0.000050 loss : 0.570196 +[07:05:24.649] iteration 30000 [724.62 sec]: learning rate : 0.000050 loss : 0.453942 +[07:06:51.054] iteration 30100 [811.02 sec]: learning rate : 0.000050 loss : 0.440703 +[07:08:17.531] iteration 30200 [897.50 sec]: learning rate : 0.000050 loss : 0.560504 +[07:09:43.923] iteration 30300 [983.90 sec]: learning rate : 0.000050 loss : 0.586051 +[07:11:10.405] iteration 30400 [1070.37 sec]: learning rate : 0.000050 loss : 0.618881 +[07:12:36.890] iteration 30500 [1156.86 sec]: learning rate : 0.000050 loss : 0.532264 +[07:14:03.294] iteration 30600 [1243.26 sec]: learning rate : 0.000050 loss : 0.720281 +[07:15:29.760] iteration 30700 [1329.73 sec]: learning rate : 0.000050 loss : 0.360706 +[07:16:56.252] iteration 30800 [1416.22 sec]: learning rate : 0.000050 loss : 0.512196 +[07:18:22.668] iteration 30900 [1502.64 sec]: learning rate : 0.000050 loss : 0.748604 +[07:19:49.144] iteration 31000 [1589.11 sec]: learning rate : 0.000050 loss : 0.396364 +[07:21:15.619] iteration 31100 [1675.59 sec]: learning rate : 0.000050 loss : 0.676904 +[07:22:42.027] iteration 31200 [1762.00 sec]: learning rate : 0.000050 loss : 0.609989 +[07:23:20.959] Epoch 14 Evaluation: +[07:24:12.535] average MSE: 0.05736875534057617 average PSNR: 28.019977493322347 average SSIM: 0.598895357029985 +[07:25:00.297] iteration 31300 [47.70 sec]: learning rate : 0.000050 loss : 0.719089 +[07:26:26.809] iteration 31400 [134.21 sec]: learning rate : 0.000050 loss : 0.607873 +[07:27:53.222] iteration 31500 [220.62 sec]: learning rate : 0.000050 loss : 0.495677 +[07:29:19.700] iteration 31600 [307.10 sec]: learning rate : 0.000050 loss : 0.413016 +[07:30:46.122] iteration 31700 [393.52 sec]: learning rate : 0.000050 loss : 0.911431 +[07:32:12.559] iteration 31800 [479.96 sec]: learning rate : 0.000050 loss : 0.769847 +[07:33:39.040] iteration 31900 [566.44 sec]: learning rate : 0.000050 loss : 0.652184 +[07:35:05.459] iteration 32000 [652.86 sec]: learning rate : 0.000050 loss : 0.547674 +[07:36:31.898] iteration 32100 [739.30 sec]: learning rate : 0.000050 loss : 0.573709 +[07:37:58.317] iteration 32200 [825.72 sec]: learning rate : 0.000050 loss : 0.363052 +[07:39:24.788] iteration 32300 [912.19 sec]: learning rate : 0.000050 loss : 0.841415 +[07:40:51.254] iteration 32400 [998.66 sec]: learning rate : 0.000050 loss : 0.641225 +[07:42:17.680] iteration 32500 [1085.08 sec]: learning rate : 0.000050 loss : 0.565872 +[07:43:44.141] iteration 32600 [1171.54 sec]: learning rate : 0.000050 loss : 0.520618 +[07:45:10.563] iteration 32700 [1257.97 sec]: learning rate : 0.000050 loss : 0.621680 +[07:46:36.998] iteration 32800 [1344.40 sec]: learning rate : 0.000050 loss : 0.717874 +[07:48:03.520] iteration 32900 [1430.92 sec]: learning rate : 0.000050 loss : 0.483343 +[07:49:29.909] iteration 33000 [1517.31 sec]: learning rate : 0.000050 loss : 0.581507 +[07:50:56.334] iteration 33100 
[1603.74 sec]: learning rate : 0.000050 loss : 0.767824 +[07:52:22.714] iteration 33200 [1690.12 sec]: learning rate : 0.000050 loss : 0.499948 +[07:53:49.190] iteration 33300 [1776.59 sec]: learning rate : 0.000050 loss : 0.707742 +[07:54:13.353] Epoch 15 Evaluation: +[07:55:03.912] average MSE: 0.05735837295651436 average PSNR: 28.03474875788648 average SSIM: 0.5992890938480476 +[07:56:06.471] iteration 33400 [62.50 sec]: learning rate : 0.000050 loss : 0.445151 +[07:57:32.855] iteration 33500 [148.88 sec]: learning rate : 0.000050 loss : 0.250337 +[07:58:59.325] iteration 33600 [235.35 sec]: learning rate : 0.000050 loss : 0.474204 +[08:00:25.796] iteration 33700 [321.82 sec]: learning rate : 0.000050 loss : 0.867915 +[08:01:52.176] iteration 33800 [408.20 sec]: learning rate : 0.000050 loss : 0.559357 +[08:03:18.648] iteration 33900 [494.67 sec]: learning rate : 0.000050 loss : 0.688633 +[08:04:45.105] iteration 34000 [581.13 sec]: learning rate : 0.000050 loss : 0.574571 +[08:06:11.479] iteration 34100 [667.50 sec]: learning rate : 0.000050 loss : 0.713666 +[08:07:37.927] iteration 34200 [753.95 sec]: learning rate : 0.000050 loss : 0.487302 +[08:09:04.302] iteration 34300 [840.33 sec]: learning rate : 0.000050 loss : 0.699434 +[08:10:30.701] iteration 34400 [926.73 sec]: learning rate : 0.000050 loss : 0.561770 +[08:11:57.127] iteration 34500 [1013.15 sec]: learning rate : 0.000050 loss : 0.723415 +[08:13:23.489] iteration 34600 [1099.51 sec]: learning rate : 0.000050 loss : 0.423014 +[08:14:49.878] iteration 34700 [1185.90 sec]: learning rate : 0.000050 loss : 0.565815 +[08:16:16.266] iteration 34800 [1272.29 sec]: learning rate : 0.000050 loss : 0.642617 +[08:17:42.584] iteration 34900 [1358.61 sec]: learning rate : 0.000050 loss : 0.709588 +[08:19:08.983] iteration 35000 [1445.01 sec]: learning rate : 0.000050 loss : 0.358515 +[08:20:35.345] iteration 35100 [1531.37 sec]: learning rate : 0.000050 loss : 0.476024 +[08:22:01.658] iteration 35200 [1617.68 sec]: learning rate : 0.000050 loss : 0.422620 +[08:23:28.058] iteration 35300 [1704.08 sec]: learning rate : 0.000050 loss : 0.454224 +[08:24:54.441] iteration 35400 [1790.47 sec]: learning rate : 0.000050 loss : 0.359873 +[08:25:03.905] Epoch 16 Evaluation: +[08:25:54.688] average MSE: 0.05684836208820343 average PSNR: 28.073393303042337 average SSIM: 0.6002361996530216 +[08:27:11.720] iteration 35500 [76.97 sec]: learning rate : 0.000050 loss : 0.816112 +[08:28:38.135] iteration 35600 [163.38 sec]: learning rate : 0.000050 loss : 0.562299 +[08:30:04.486] iteration 35700 [249.73 sec]: learning rate : 0.000050 loss : 0.587143 +[08:31:30.786] iteration 35800 [336.03 sec]: learning rate : 0.000050 loss : 0.578356 +[08:32:57.162] iteration 35900 [422.41 sec]: learning rate : 0.000050 loss : 0.661439 +[08:34:23.519] iteration 36000 [508.77 sec]: learning rate : 0.000050 loss : 0.516910 +[08:35:49.877] iteration 36100 [595.12 sec]: learning rate : 0.000050 loss : 0.703663 +[08:37:16.289] iteration 36200 [681.54 sec]: learning rate : 0.000050 loss : 0.786527 +[08:38:42.642] iteration 36300 [767.89 sec]: learning rate : 0.000050 loss : 0.558988 +[08:40:09.016] iteration 36400 [854.26 sec]: learning rate : 0.000050 loss : 0.688048 +[08:41:35.365] iteration 36500 [940.61 sec]: learning rate : 0.000050 loss : 0.738928 +[08:43:01.786] iteration 36600 [1027.03 sec]: learning rate : 0.000050 loss : 0.435807 +[08:44:28.203] iteration 36700 [1113.45 sec]: learning rate : 0.000050 loss : 0.416730 +[08:45:54.581] iteration 36800 [1199.83 sec]: 
learning rate : 0.000050 loss : 0.648582 +[08:47:21.022] iteration 36900 [1286.27 sec]: learning rate : 0.000050 loss : 0.440464 +[08:48:47.437] iteration 37000 [1372.68 sec]: learning rate : 0.000050 loss : 0.369077 +[08:50:13.869] iteration 37100 [1459.11 sec]: learning rate : 0.000050 loss : 0.619662 +[08:51:40.346] iteration 37200 [1545.59 sec]: learning rate : 0.000050 loss : 0.764441 +[08:53:06.751] iteration 37300 [1632.00 sec]: learning rate : 0.000050 loss : 0.692722 +[08:54:33.198] iteration 37400 [1718.44 sec]: learning rate : 0.000050 loss : 0.704968 +[08:55:54.420] Epoch 17 Evaluation: +[08:56:44.008] average MSE: 0.056925296783447266 average PSNR: 28.0674683580224 average SSIM: 0.6006933518876415 +[08:56:49.430] iteration 37500 [5.36 sec]: learning rate : 0.000050 loss : 0.712171 +[08:58:15.954] iteration 37600 [91.88 sec]: learning rate : 0.000050 loss : 0.777882 +[08:59:42.457] iteration 37700 [178.39 sec]: learning rate : 0.000050 loss : 0.678342 +[09:01:08.903] iteration 37800 [264.83 sec]: learning rate : 0.000050 loss : 0.458404 +[09:02:35.426] iteration 37900 [351.36 sec]: learning rate : 0.000050 loss : 0.869741 +[09:04:01.918] iteration 38000 [437.85 sec]: learning rate : 0.000050 loss : 0.428745 +[09:05:28.361] iteration 38100 [524.29 sec]: learning rate : 0.000050 loss : 0.553082 +[09:06:54.858] iteration 38200 [610.79 sec]: learning rate : 0.000050 loss : 0.672991 +[09:08:21.298] iteration 38300 [697.23 sec]: learning rate : 0.000050 loss : 0.406992 +[09:09:47.804] iteration 38400 [783.73 sec]: learning rate : 0.000050 loss : 0.532689 +[09:11:14.317] iteration 38500 [870.25 sec]: learning rate : 0.000050 loss : 1.024871 +[09:12:40.735] iteration 38600 [956.66 sec]: learning rate : 0.000050 loss : 0.685313 +[09:14:07.173] iteration 38700 [1043.10 sec]: learning rate : 0.000050 loss : 0.486882 +[09:15:33.609] iteration 38800 [1129.54 sec]: learning rate : 0.000050 loss : 0.835683 +[09:17:00.054] iteration 38900 [1215.98 sec]: learning rate : 0.000050 loss : 0.700505 +[09:18:26.535] iteration 39000 [1302.47 sec]: learning rate : 0.000050 loss : 0.550026 +[09:19:52.941] iteration 39100 [1388.87 sec]: learning rate : 0.000050 loss : 0.795890 +[09:21:19.384] iteration 39200 [1475.31 sec]: learning rate : 0.000050 loss : 0.322924 +[09:22:45.866] iteration 39300 [1561.80 sec]: learning rate : 0.000050 loss : 0.621844 +[09:24:12.246] iteration 39400 [1648.18 sec]: learning rate : 0.000050 loss : 0.402305 +[09:25:38.653] iteration 39500 [1734.58 sec]: learning rate : 0.000050 loss : 0.817296 +[09:26:45.150] Epoch 18 Evaluation: +[09:27:35.592] average MSE: 0.05631484463810921 average PSNR: 28.117416817037235 average SSIM: 0.6001188954795013 +[09:27:55.784] iteration 39600 [20.13 sec]: learning rate : 0.000050 loss : 0.909721 +[09:29:22.122] iteration 39700 [106.47 sec]: learning rate : 0.000050 loss : 0.579749 +[09:30:48.636] iteration 39800 [192.98 sec]: learning rate : 0.000050 loss : 0.804316 +[09:32:15.041] iteration 39900 [279.39 sec]: learning rate : 0.000050 loss : 0.635820 +[09:33:41.497] iteration 40000 [365.84 sec]: learning rate : 0.000013 loss : 0.363064 +[09:33:41.663] save model to model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/iter_40000.pth +[09:35:08.082] iteration 40100 [452.43 sec]: learning rate : 0.000025 loss : 0.376859 +[09:36:34.471] iteration 40200 [538.82 sec]: learning rate : 0.000025 loss : 0.490092 +[09:38:00.925] iteration 40300 [625.27 sec]: learning rate : 0.000025 loss : 0.335904 +[09:39:27.381] iteration 40400 [711.73 sec]: learning 
rate : 0.000025 loss : 0.623007 +[09:40:53.773] iteration 40500 [798.12 sec]: learning rate : 0.000025 loss : 0.551527 +[09:42:20.234] iteration 40600 [884.58 sec]: learning rate : 0.000025 loss : 0.460647 +[09:43:46.684] iteration 40700 [971.03 sec]: learning rate : 0.000025 loss : 0.656664 +[09:45:13.063] iteration 40800 [1057.41 sec]: learning rate : 0.000025 loss : 0.732658 +[09:46:39.567] iteration 40900 [1143.91 sec]: learning rate : 0.000025 loss : 0.530695 +[09:48:06.005] iteration 41000 [1230.35 sec]: learning rate : 0.000025 loss : 0.315458 +[09:49:32.374] iteration 41100 [1316.72 sec]: learning rate : 0.000025 loss : 0.759369 +[09:50:58.787] iteration 41200 [1403.13 sec]: learning rate : 0.000025 loss : 0.686400 +[09:52:25.198] iteration 41300 [1489.54 sec]: learning rate : 0.000025 loss : 0.590403 +[09:53:51.581] iteration 41400 [1575.93 sec]: learning rate : 0.000025 loss : 0.770763 +[09:55:18.013] iteration 41500 [1662.36 sec]: learning rate : 0.000025 loss : 0.547469 +[09:56:44.471] iteration 41600 [1748.82 sec]: learning rate : 0.000025 loss : 0.470961 +[09:57:36.289] Epoch 19 Evaluation: +[09:58:28.135] average MSE: 0.05601537600159645 average PSNR: 28.15529984338068 average SSIM: 0.6014081929508498 +[09:59:02.962] iteration 41700 [34.76 sec]: learning rate : 0.000025 loss : 0.525289 +[10:00:29.492] iteration 41800 [121.29 sec]: learning rate : 0.000025 loss : 0.671515 +[10:01:55.888] iteration 41900 [207.69 sec]: learning rate : 0.000025 loss : 0.814825 +[10:03:22.357] iteration 42000 [294.16 sec]: learning rate : 0.000025 loss : 0.617432 +[10:04:48.839] iteration 42100 [380.64 sec]: learning rate : 0.000025 loss : 0.618249 +[10:06:15.224] iteration 42200 [467.03 sec]: learning rate : 0.000025 loss : 0.473086 +[10:07:41.734] iteration 42300 [553.55 sec]: learning rate : 0.000025 loss : 1.104049 +[10:09:08.204] iteration 42400 [640.01 sec]: learning rate : 0.000025 loss : 0.691411 +[10:10:34.629] iteration 42500 [726.43 sec]: learning rate : 0.000025 loss : 0.718796 +[10:12:01.109] iteration 42600 [812.93 sec]: learning rate : 0.000025 loss : 0.472593 +[10:13:27.597] iteration 42700 [899.40 sec]: learning rate : 0.000025 loss : 0.359976 +[10:14:54.011] iteration 42800 [985.82 sec]: learning rate : 0.000025 loss : 0.829062 +[10:16:20.524] iteration 42900 [1072.33 sec]: learning rate : 0.000025 loss : 0.429105 +[10:17:46.950] iteration 43000 [1158.75 sec]: learning rate : 0.000025 loss : 0.363114 +[10:19:13.455] iteration 43100 [1245.26 sec]: learning rate : 0.000025 loss : 0.432581 +[10:20:39.934] iteration 43200 [1331.74 sec]: learning rate : 0.000025 loss : 0.553386 +[10:22:06.358] iteration 43300 [1418.16 sec]: learning rate : 0.000025 loss : 0.820087 +[10:23:32.794] iteration 43400 [1504.60 sec]: learning rate : 0.000025 loss : 0.595191 +[10:24:59.293] iteration 43500 [1591.10 sec]: learning rate : 0.000025 loss : 0.495027 +[10:26:25.707] iteration 43600 [1677.51 sec]: learning rate : 0.000025 loss : 0.619841 +[10:27:52.163] iteration 43700 [1763.97 sec]: learning rate : 0.000025 loss : 0.592231 +[10:28:29.294] Epoch 20 Evaluation: +[10:29:18.708] average MSE: 0.056624483317136765 average PSNR: 28.09700377970651 average SSIM: 0.6018976767394358 +[10:30:08.321] iteration 43800 [49.55 sec]: learning rate : 0.000025 loss : 0.593979 +[10:31:34.695] iteration 43900 [135.92 sec]: learning rate : 0.000025 loss : 0.663432 +[10:33:01.189] iteration 44000 [222.42 sec]: learning rate : 0.000025 loss : 0.402033 +[10:34:27.654] iteration 44100 [308.88 sec]: learning rate : 0.000025 
loss : 0.504831 +[10:35:54.076] iteration 44200 [395.31 sec]: learning rate : 0.000025 loss : 0.989602 +[10:37:20.535] iteration 44300 [481.76 sec]: learning rate : 0.000025 loss : 0.358256 +[10:38:46.989] iteration 44400 [568.22 sec]: learning rate : 0.000025 loss : 0.594282 +[10:40:13.393] iteration 44500 [654.62 sec]: learning rate : 0.000025 loss : 0.659238 +[10:41:39.843] iteration 44600 [741.07 sec]: learning rate : 0.000025 loss : 0.565837 +[10:43:06.314] iteration 44700 [827.61 sec]: learning rate : 0.000025 loss : 0.360541 +[10:44:32.710] iteration 44800 [913.94 sec]: learning rate : 0.000025 loss : 0.922050 +[10:45:59.197] iteration 44900 [1000.43 sec]: learning rate : 0.000025 loss : 0.984147 +[10:47:25.632] iteration 45000 [1086.86 sec]: learning rate : 0.000025 loss : 0.438175 +[10:48:52.087] iteration 45100 [1173.32 sec]: learning rate : 0.000025 loss : 0.763101 +[10:50:18.505] iteration 45200 [1259.74 sec]: learning rate : 0.000025 loss : 0.656497 +[10:51:44.907] iteration 45300 [1346.14 sec]: learning rate : 0.000025 loss : 0.442770 +[10:53:11.365] iteration 45400 [1432.60 sec]: learning rate : 0.000025 loss : 0.447421 +[10:54:37.809] iteration 45500 [1519.04 sec]: learning rate : 0.000025 loss : 0.647015 +[10:56:04.224] iteration 45600 [1605.45 sec]: learning rate : 0.000025 loss : 0.710665 +[10:57:30.667] iteration 45700 [1691.90 sec]: learning rate : 0.000025 loss : 0.579126 +[10:58:57.126] iteration 45800 [1778.36 sec]: learning rate : 0.000025 loss : 0.579905 +[10:59:19.557] Epoch 21 Evaluation: +[11:00:11.100] average MSE: 0.05617809295654297 average PSNR: 28.14213094091348 average SSIM: 0.6015142316107447 +[11:01:15.282] iteration 45900 [64.12 sec]: learning rate : 0.000025 loss : 0.645760 +[11:02:41.809] iteration 46000 [150.65 sec]: learning rate : 0.000025 loss : 0.419018 +[11:04:08.271] iteration 46100 [237.11 sec]: learning rate : 0.000025 loss : 0.447862 +[11:05:34.670] iteration 46200 [323.51 sec]: learning rate : 0.000025 loss : 0.800858 +[11:07:01.129] iteration 46300 [409.97 sec]: learning rate : 0.000025 loss : 0.659421 +[11:08:27.584] iteration 46400 [496.42 sec]: learning rate : 0.000025 loss : 0.646652 +[11:09:53.992] iteration 46500 [582.83 sec]: learning rate : 0.000025 loss : 0.720503 +[11:11:20.448] iteration 46600 [669.29 sec]: learning rate : 0.000025 loss : 0.491437 +[11:12:46.950] iteration 46700 [755.79 sec]: learning rate : 0.000025 loss : 0.460427 +[11:14:13.339] iteration 46800 [842.18 sec]: learning rate : 0.000025 loss : 0.591604 +[11:15:39.794] iteration 46900 [928.63 sec]: learning rate : 0.000025 loss : 0.604769 +[11:17:06.265] iteration 47000 [1015.10 sec]: learning rate : 0.000025 loss : 0.480521 +[11:18:32.689] iteration 47100 [1101.53 sec]: learning rate : 0.000025 loss : 0.897655 +[11:19:59.148] iteration 47200 [1187.99 sec]: learning rate : 0.000025 loss : 0.528064 +[11:21:25.558] iteration 47300 [1274.40 sec]: learning rate : 0.000025 loss : 0.610464 +[11:22:52.027] iteration 47400 [1360.86 sec]: learning rate : 0.000025 loss : 0.472183 +[11:24:18.511] iteration 47500 [1447.35 sec]: learning rate : 0.000025 loss : 0.780611 +[11:25:44.919] iteration 47600 [1533.76 sec]: learning rate : 0.000025 loss : 0.604246 +[11:27:11.349] iteration 47700 [1620.19 sec]: learning rate : 0.000025 loss : 0.454554 +[11:28:37.784] iteration 47800 [1706.62 sec]: learning rate : 0.000025 loss : 0.695632 +[11:30:04.259] iteration 47900 [1793.10 sec]: learning rate : 0.000025 loss : 0.625651 +[11:30:12.015] Epoch 22 Evaluation: +[11:31:03.549] average 
MSE: 0.056389447301626205 average PSNR: 28.127224401290913 average SSIM: 0.6013819873161014 +[11:32:22.544] iteration 48000 [78.93 sec]: learning rate : 0.000025 loss : 0.774685 +[11:33:48.955] iteration 48100 [165.34 sec]: learning rate : 0.000025 loss : 0.707850 +[11:35:15.464] iteration 48200 [251.85 sec]: learning rate : 0.000025 loss : 0.563945 +[11:36:41.933] iteration 48300 [338.32 sec]: learning rate : 0.000025 loss : 0.480162 +[11:38:08.359] iteration 48400 [424.75 sec]: learning rate : 0.000025 loss : 0.679652 +[11:39:34.863] iteration 48500 [511.25 sec]: learning rate : 0.000025 loss : 0.626497 +[11:41:01.293] iteration 48600 [597.68 sec]: learning rate : 0.000025 loss : 0.573296 +[11:42:27.728] iteration 48700 [684.12 sec]: learning rate : 0.000025 loss : 0.444118 +[11:43:54.193] iteration 48800 [770.58 sec]: learning rate : 0.000025 loss : 0.329037 +[11:45:20.694] iteration 48900 [857.08 sec]: learning rate : 0.000025 loss : 0.463352 +[11:46:47.127] iteration 49000 [943.52 sec]: learning rate : 0.000025 loss : 0.772944 +[11:48:13.635] iteration 49100 [1030.02 sec]: learning rate : 0.000025 loss : 0.575231 +[11:49:40.072] iteration 49200 [1116.46 sec]: learning rate : 0.000025 loss : 0.396059 +[11:51:06.583] iteration 49300 [1202.97 sec]: learning rate : 0.000025 loss : 0.503983 +[11:52:33.060] iteration 49400 [1289.45 sec]: learning rate : 0.000025 loss : 0.443574 +[11:53:59.504] iteration 49500 [1375.89 sec]: learning rate : 0.000025 loss : 0.438454 +[11:55:25.984] iteration 49600 [1462.37 sec]: learning rate : 0.000025 loss : 0.560583 +[11:56:52.475] iteration 49700 [1548.86 sec]: learning rate : 0.000025 loss : 0.687025 +[11:58:18.900] iteration 49800 [1635.29 sec]: learning rate : 0.000025 loss : 0.718234 +[11:59:45.402] iteration 49900 [1721.79 sec]: learning rate : 0.000025 loss : 0.551758 +[12:01:04.891] Epoch 23 Evaluation: +[12:01:54.274] average MSE: 0.05619407817721367 average PSNR: 28.137199058054478 average SSIM: 0.6022103144897615 +[12:02:01.523] iteration 50000 [7.19 sec]: learning rate : 0.000025 loss : 0.402243 +[12:03:27.916] iteration 50100 [93.58 sec]: learning rate : 0.000025 loss : 0.940795 +[12:04:54.365] iteration 50200 [180.03 sec]: learning rate : 0.000025 loss : 0.735754 +[12:06:20.787] iteration 50300 [266.45 sec]: learning rate : 0.000025 loss : 0.681577 +[12:07:47.288] iteration 50400 [352.95 sec]: learning rate : 0.000025 loss : 0.627277 +[12:09:13.808] iteration 50500 [439.47 sec]: learning rate : 0.000025 loss : 0.449623 +[12:10:40.273] iteration 50600 [525.94 sec]: learning rate : 0.000025 loss : 0.551163 +[12:12:06.806] iteration 50700 [612.47 sec]: learning rate : 0.000025 loss : 0.697463 +[12:13:33.295] iteration 50800 [698.96 sec]: learning rate : 0.000025 loss : 0.463599 +[12:14:59.766] iteration 50900 [785.43 sec]: learning rate : 0.000025 loss : 0.427070 +[12:16:26.282] iteration 51000 [871.95 sec]: learning rate : 0.000025 loss : 0.528331 +[12:17:52.796] iteration 51100 [958.46 sec]: learning rate : 0.000025 loss : 0.660806 +[12:19:19.253] iteration 51200 [1044.92 sec]: learning rate : 0.000025 loss : 0.335331 +[12:20:45.800] iteration 51300 [1131.46 sec]: learning rate : 0.000025 loss : 0.647946 +[12:22:12.325] iteration 51400 [1217.99 sec]: learning rate : 0.000025 loss : 0.608254 +[12:23:38.799] iteration 51500 [1304.46 sec]: learning rate : 0.000025 loss : 0.552035 +[12:25:05.358] iteration 51600 [1391.02 sec]: learning rate : 0.000025 loss : 0.546658 +[12:26:31.879] iteration 51700 [1477.54 sec]: learning rate : 0.000025 loss : 
0.312416 +[12:27:58.357] iteration 51800 [1564.02 sec]: learning rate : 0.000025 loss : 0.613077 +[12:29:24.882] iteration 51900 [1650.55 sec]: learning rate : 0.000025 loss : 0.546069 +[12:30:51.428] iteration 52000 [1737.09 sec]: learning rate : 0.000025 loss : 0.644281 +[12:31:56.299] Epoch 24 Evaluation: +[12:32:45.779] average MSE: 0.056167542934417725 average PSNR: 28.151376733469196 average SSIM: 0.6027998668160449 +[12:33:07.642] iteration 52100 [21.80 sec]: learning rate : 0.000025 loss : 0.657257 +[12:34:34.246] iteration 52200 [108.40 sec]: learning rate : 0.000025 loss : 0.779841 +[12:36:00.715] iteration 52300 [194.87 sec]: learning rate : 0.000025 loss : 0.405920 +[12:37:27.321] iteration 52400 [281.48 sec]: learning rate : 0.000025 loss : 0.605536 +[12:38:53.830] iteration 52500 [367.99 sec]: learning rate : 0.000025 loss : 0.558868 +[12:40:20.342] iteration 52600 [454.50 sec]: learning rate : 0.000025 loss : 0.376798 +[12:41:46.918] iteration 52700 [541.08 sec]: learning rate : 0.000025 loss : 0.565831 +[12:43:13.438] iteration 52800 [627.60 sec]: learning rate : 0.000025 loss : 0.542409 +[12:44:39.945] iteration 52900 [714.10 sec]: learning rate : 0.000025 loss : 0.641804 +[12:46:06.528] iteration 53000 [800.69 sec]: learning rate : 0.000025 loss : 0.884721 +[12:47:33.040] iteration 53100 [887.20 sec]: learning rate : 0.000025 loss : 0.722085 +[12:48:59.585] iteration 53200 [973.74 sec]: learning rate : 0.000025 loss : 0.506006 +[12:50:26.162] iteration 53300 [1060.32 sec]: learning rate : 0.000025 loss : 0.625562 +[12:51:52.682] iteration 53400 [1146.84 sec]: learning rate : 0.000025 loss : 0.262767 +[12:53:19.200] iteration 53500 [1233.36 sec]: learning rate : 0.000025 loss : 0.746783 +[12:54:45.698] iteration 53600 [1319.86 sec]: learning rate : 0.000025 loss : 0.692761 +[12:56:12.297] iteration 53700 [1406.46 sec]: learning rate : 0.000025 loss : 0.676696 +[12:57:38.820] iteration 53800 [1492.98 sec]: learning rate : 0.000025 loss : 0.746134 +[12:59:05.316] iteration 53900 [1579.47 sec]: learning rate : 0.000025 loss : 0.662841 +[13:00:31.873] iteration 54000 [1666.03 sec]: learning rate : 0.000025 loss : 0.523787 +[13:01:58.428] iteration 54100 [1752.59 sec]: learning rate : 0.000025 loss : 0.410977 +[13:02:48.647] Epoch 25 Evaluation: +[13:03:40.321] average MSE: 0.05629804730415344 average PSNR: 28.133603538526234 average SSIM: 0.6023510864342567 +[13:04:16.922] iteration 54200 [36.54 sec]: learning rate : 0.000025 loss : 0.837376 +[13:05:43.529] iteration 54300 [123.16 sec]: learning rate : 0.000025 loss : 0.456987 +[13:07:10.031] iteration 54400 [209.65 sec]: learning rate : 0.000025 loss : 0.827093 +[13:08:36.560] iteration 54500 [296.18 sec]: learning rate : 0.000025 loss : 0.655596 +[13:10:03.131] iteration 54600 [382.75 sec]: learning rate : 0.000025 loss : 0.302919 +[13:11:29.660] iteration 54700 [469.28 sec]: learning rate : 0.000025 loss : 0.735057 +[13:12:56.244] iteration 54800 [555.86 sec]: learning rate : 0.000025 loss : 0.690371 +[13:14:22.749] iteration 54900 [642.37 sec]: learning rate : 0.000025 loss : 0.429313 +[13:15:49.339] iteration 55000 [728.96 sec]: learning rate : 0.000025 loss : 0.694592 +[13:17:15.895] iteration 55100 [815.51 sec]: learning rate : 0.000025 loss : 0.566186 +[13:18:42.401] iteration 55200 [902.02 sec]: learning rate : 0.000025 loss : 0.693281 +[13:20:08.972] iteration 55300 [988.59 sec]: learning rate : 0.000025 loss : 0.461440 +[13:21:35.474] iteration 55400 [1075.09 sec]: learning rate : 0.000025 loss : 0.679073 
+[13:23:02.032] iteration 55500 [1161.65 sec]: learning rate : 0.000025 loss : 0.982108 +[13:24:28.613] iteration 55600 [1248.23 sec]: learning rate : 0.000025 loss : 0.424559 +[13:25:55.083] iteration 55700 [1334.70 sec]: learning rate : 0.000025 loss : 0.555962 +[13:27:21.644] iteration 55800 [1421.26 sec]: learning rate : 0.000025 loss : 0.553236 +[13:28:48.170] iteration 55900 [1507.79 sec]: learning rate : 0.000025 loss : 0.294827 +[13:30:14.660] iteration 56000 [1594.28 sec]: learning rate : 0.000025 loss : 0.384868 +[13:31:41.198] iteration 56100 [1680.81 sec]: learning rate : 0.000025 loss : 0.484402 +[13:33:07.754] iteration 56200 [1767.37 sec]: learning rate : 0.000025 loss : 0.449657 +[13:33:43.173] Epoch 26 Evaluation: +[13:34:32.592] average MSE: 0.056351084262132645 average PSNR: 28.132602312096125 average SSIM: 0.6011753728732745 +[13:35:23.851] iteration 56300 [51.20 sec]: learning rate : 0.000025 loss : 0.627679 +[13:36:50.419] iteration 56400 [137.77 sec]: learning rate : 0.000025 loss : 0.508496 +[13:38:16.877] iteration 56500 [224.22 sec]: learning rate : 0.000025 loss : 0.432239 +[13:39:43.379] iteration 56600 [310.73 sec]: learning rate : 0.000025 loss : 0.511520 +[13:41:09.931] iteration 56700 [397.28 sec]: learning rate : 0.000025 loss : 0.634993 +[13:42:36.359] iteration 56800 [483.71 sec]: learning rate : 0.000025 loss : 0.582613 +[13:44:02.842] iteration 56900 [570.19 sec]: learning rate : 0.000025 loss : 0.494479 +[13:45:29.270] iteration 57000 [656.62 sec]: learning rate : 0.000025 loss : 0.779464 +[13:46:55.740] iteration 57100 [743.09 sec]: learning rate : 0.000025 loss : 0.731628 +[13:48:22.234] iteration 57200 [829.58 sec]: learning rate : 0.000025 loss : 0.485325 +[13:49:48.655] iteration 57300 [916.00 sec]: learning rate : 0.000025 loss : 0.590943 +[13:51:15.130] iteration 57400 [1002.48 sec]: learning rate : 0.000025 loss : 0.615251 +[13:52:41.633] iteration 57500 [1088.98 sec]: learning rate : 0.000025 loss : 0.486397 +[13:54:08.044] iteration 57600 [1175.39 sec]: learning rate : 0.000025 loss : 1.037342 +[13:55:34.528] iteration 57700 [1261.87 sec]: learning rate : 0.000025 loss : 0.733652 +[13:57:00.972] iteration 57800 [1348.32 sec]: learning rate : 0.000025 loss : 0.357618 +[13:58:27.468] iteration 57900 [1434.81 sec]: learning rate : 0.000025 loss : 0.470172 +[13:59:53.984] iteration 58000 [1521.33 sec]: learning rate : 0.000025 loss : 0.812234 +[14:01:20.426] iteration 58100 [1607.77 sec]: learning rate : 0.000025 loss : 0.585702 +[14:02:46.923] iteration 58200 [1694.27 sec]: learning rate : 0.000025 loss : 0.698750 +[14:04:13.426] iteration 58300 [1780.77 sec]: learning rate : 0.000025 loss : 0.529907 +[14:04:34.145] Epoch 27 Evaluation: +[14:05:26.319] average MSE: 0.055945634841918945 average PSNR: 28.1621294474958 average SSIM: 0.6014718104991212 +[14:06:32.259] iteration 58400 [65.88 sec]: learning rate : 0.000025 loss : 0.653305 +[14:07:58.787] iteration 58500 [152.40 sec]: learning rate : 0.000025 loss : 0.490572 +[14:09:25.293] iteration 58600 [238.91 sec]: learning rate : 0.000025 loss : 0.523372 +[14:10:51.742] iteration 58700 [325.36 sec]: learning rate : 0.000025 loss : 0.751589 +[14:12:18.282] iteration 58800 [411.90 sec]: learning rate : 0.000025 loss : 0.727068 +[14:13:44.785] iteration 58900 [498.40 sec]: learning rate : 0.000025 loss : 0.856851 +[14:15:11.255] iteration 59000 [584.87 sec]: learning rate : 0.000025 loss : 0.934135 +[14:16:37.777] iteration 59100 [671.40 sec]: learning rate : 0.000025 loss : 0.638801 +[14:18:04.309] 
iteration 59200 [757.93 sec]: learning rate : 0.000025 loss : 0.501801 +[14:19:30.790] iteration 59300 [844.41 sec]: learning rate : 0.000025 loss : 0.610891 +[14:20:57.291] iteration 59400 [930.91 sec]: learning rate : 0.000025 loss : 0.710253 +[14:22:23.773] iteration 59500 [1017.39 sec]: learning rate : 0.000025 loss : 0.519027 +[14:23:50.344] iteration 59600 [1103.96 sec]: learning rate : 0.000025 loss : 0.527423 +[14:25:16.822] iteration 59700 [1190.44 sec]: learning rate : 0.000025 loss : 0.451692 +[14:26:43.292] iteration 59800 [1276.91 sec]: learning rate : 0.000025 loss : 0.579170 +[14:28:09.839] iteration 59900 [1363.46 sec]: learning rate : 0.000025 loss : 1.001529 +[14:29:36.296] iteration 60000 [1449.91 sec]: learning rate : 0.000006 loss : 0.675476 +[14:29:36.450] save model to model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/iter_60000.pth +[14:31:02.959] iteration 60100 [1536.58 sec]: learning rate : 0.000013 loss : 0.558823 +[14:32:29.476] iteration 60200 [1623.09 sec]: learning rate : 0.000013 loss : 0.518746 +[14:33:55.926] iteration 60300 [1709.54 sec]: learning rate : 0.000013 loss : 0.898552 +[14:35:22.448] iteration 60400 [1796.07 sec]: learning rate : 0.000013 loss : 0.752587 +[14:35:28.477] Epoch 28 Evaluation: +[14:36:20.256] average MSE: 0.05574697256088257 average PSNR: 28.185083777282145 average SSIM: 0.6031350796064147 +[14:37:40.972] iteration 60500 [80.65 sec]: learning rate : 0.000013 loss : 0.581153 +[14:39:07.445] iteration 60600 [167.13 sec]: learning rate : 0.000013 loss : 0.603024 +[14:40:33.966] iteration 60700 [253.65 sec]: learning rate : 0.000013 loss : 0.434815 +[14:42:00.428] iteration 60800 [340.11 sec]: learning rate : 0.000013 loss : 0.471425 +[14:43:26.977] iteration 60900 [426.66 sec]: learning rate : 0.000013 loss : 0.559749 +[14:44:53.501] iteration 61000 [513.18 sec]: learning rate : 0.000013 loss : 0.716136 +[14:46:19.965] iteration 61100 [599.65 sec]: learning rate : 0.000013 loss : 0.422424 +[14:47:46.474] iteration 61200 [686.16 sec]: learning rate : 0.000013 loss : 0.497832 +[14:49:12.979] iteration 61300 [772.66 sec]: learning rate : 0.000013 loss : 0.810206 +[14:50:39.450] iteration 61400 [859.13 sec]: learning rate : 0.000013 loss : 0.505629 +[14:52:05.954] iteration 61500 [945.64 sec]: learning rate : 0.000013 loss : 0.406566 +[14:53:32.437] iteration 61600 [1032.12 sec]: learning rate : 0.000013 loss : 0.534768 +[14:54:58.906] iteration 61700 [1118.59 sec]: learning rate : 0.000013 loss : 0.434856 +[14:56:25.449] iteration 61800 [1205.13 sec]: learning rate : 0.000013 loss : 0.578056 +[14:57:51.910] iteration 61900 [1291.59 sec]: learning rate : 0.000013 loss : 0.456819 +[14:59:18.445] iteration 62000 [1378.13 sec]: learning rate : 0.000013 loss : 0.363509 +[15:00:44.914] iteration 62100 [1464.60 sec]: learning rate : 0.000013 loss : 0.747022 +[15:02:11.393] iteration 62200 [1551.07 sec]: learning rate : 0.000013 loss : 1.097758 +[15:03:37.919] iteration 62300 [1637.60 sec]: learning rate : 0.000013 loss : 0.496404 +[15:05:04.446] iteration 62400 [1724.13 sec]: learning rate : 0.000013 loss : 0.487567 +[15:06:22.260] Epoch 29 Evaluation: +[15:07:11.525] average MSE: 0.05594076216220856 average PSNR: 28.169286483659107 average SSIM: 0.6031054620432281 +[15:07:20.412] iteration 62500 [8.82 sec]: learning rate : 0.000013 loss : 0.592747 +[15:08:46.940] iteration 62600 [95.35 sec]: learning rate : 0.000013 loss : 0.628585 +[15:10:13.461] iteration 62700 [181.87 sec]: learning rate : 0.000013 loss : 0.384598 +[15:11:39.922] 
iteration 62800 [268.33 sec]: learning rate : 0.000013 loss : 0.393135 +[15:13:06.464] iteration 62900 [354.88 sec]: learning rate : 0.000013 loss : 0.453917 +[15:14:33.007] iteration 63000 [441.42 sec]: learning rate : 0.000013 loss : 0.881673 +[15:15:59.490] iteration 63100 [527.90 sec]: learning rate : 0.000013 loss : 0.611624 +[15:17:26.020] iteration 63200 [614.43 sec]: learning rate : 0.000013 loss : 0.677593 +[15:18:52.574] iteration 63300 [700.99 sec]: learning rate : 0.000013 loss : 0.854701 +[15:20:19.069] iteration 63400 [787.48 sec]: learning rate : 0.000013 loss : 0.613128 +[15:21:45.627] iteration 63500 [874.04 sec]: learning rate : 0.000013 loss : 0.469158 +[15:23:12.166] iteration 63600 [960.58 sec]: learning rate : 0.000013 loss : 0.376440 +[15:24:38.649] iteration 63700 [1047.06 sec]: learning rate : 0.000013 loss : 0.531595 +[15:26:05.240] iteration 63800 [1133.65 sec]: learning rate : 0.000013 loss : 0.466304 +[15:27:31.718] iteration 63900 [1220.13 sec]: learning rate : 0.000013 loss : 0.462211 +[15:28:58.271] iteration 64000 [1306.68 sec]: learning rate : 0.000013 loss : 0.424574 +[15:30:24.833] iteration 64100 [1393.25 sec]: learning rate : 0.000013 loss : 0.807064 +[15:31:51.314] iteration 64200 [1479.73 sec]: learning rate : 0.000013 loss : 0.832428 +[15:33:17.878] iteration 64300 [1566.29 sec]: learning rate : 0.000013 loss : 0.692171 +[15:34:44.414] iteration 64400 [1652.83 sec]: learning rate : 0.000013 loss : 0.638841 +[15:36:10.901] iteration 64500 [1739.31 sec]: learning rate : 0.000013 loss : 0.497965 +[15:37:14.056] Epoch 30 Evaluation: +[15:38:05.601] average MSE: 0.05589057877659798 average PSNR: 28.178983826035836 average SSIM: 0.6034509867510197 +[15:38:29.219] iteration 64600 [23.56 sec]: learning rate : 0.000013 loss : 0.523350 +[15:39:55.823] iteration 64700 [110.16 sec]: learning rate : 0.000013 loss : 0.764641 +[15:41:22.302] iteration 64800 [196.64 sec]: learning rate : 0.000013 loss : 0.590831 +[15:42:48.854] iteration 64900 [283.19 sec]: learning rate : 0.000013 loss : 0.539176 +[15:44:15.375] iteration 65000 [369.71 sec]: learning rate : 0.000013 loss : 1.132995 +[15:45:41.827] iteration 65100 [456.16 sec]: learning rate : 0.000013 loss : 0.225983 +[15:47:08.361] iteration 65200 [542.70 sec]: learning rate : 0.000013 loss : 0.347091 +[15:48:34.862] iteration 65300 [629.20 sec]: learning rate : 0.000013 loss : 0.477921 +[15:50:01.305] iteration 65400 [715.64 sec]: learning rate : 0.000013 loss : 0.575013 +[15:51:27.867] iteration 65500 [802.20 sec]: learning rate : 0.000013 loss : 0.706436 +[15:52:54.328] iteration 65600 [888.66 sec]: learning rate : 0.000013 loss : 0.642002 +[15:54:20.816] iteration 65700 [975.15 sec]: learning rate : 0.000013 loss : 0.646353 +[15:55:47.327] iteration 65800 [1061.66 sec]: learning rate : 0.000013 loss : 0.736074 +[15:57:13.824] iteration 65900 [1148.16 sec]: learning rate : 0.000013 loss : 0.562219 +[15:58:40.394] iteration 66000 [1234.73 sec]: learning rate : 0.000013 loss : 0.663230 +[16:00:06.843] iteration 66100 [1321.18 sec]: learning rate : 0.000013 loss : 0.482782 +[16:01:33.364] iteration 66200 [1407.70 sec]: learning rate : 0.000013 loss : 0.543910 +[16:02:59.890] iteration 66300 [1494.23 sec]: learning rate : 0.000013 loss : 0.938323 +[16:04:26.351] iteration 66400 [1580.69 sec]: learning rate : 0.000013 loss : 0.458192 +[16:05:52.879] iteration 66500 [1667.21 sec]: learning rate : 0.000013 loss : 0.552881 +[16:07:19.422] iteration 66600 [1753.76 sec]: learning rate : 0.000013 loss : 0.583004 
+[16:08:07.833] Epoch 31 Evaluation: +[16:08:57.829] average MSE: 0.05564850568771362 average PSNR: 28.194598701225903 average SSIM: 0.602988278282624 +[16:09:36.149] iteration 66700 [38.26 sec]: learning rate : 0.000013 loss : 0.685287 +[16:11:02.750] iteration 66800 [124.86 sec]: learning rate : 0.000013 loss : 0.793354 +[16:12:29.358] iteration 66900 [211.47 sec]: learning rate : 0.000013 loss : 0.606054 +[16:13:55.929] iteration 67000 [298.04 sec]: learning rate : 0.000013 loss : 0.531837 +[16:15:22.483] iteration 67100 [384.59 sec]: learning rate : 0.000013 loss : 0.367714 +[16:16:49.012] iteration 67200 [471.12 sec]: learning rate : 0.000013 loss : 0.711919 +[16:18:15.626] iteration 67300 [557.73 sec]: learning rate : 0.000013 loss : 0.551802 +[16:19:42.178] iteration 67400 [644.29 sec]: learning rate : 0.000013 loss : 0.485708 +[16:21:08.737] iteration 67500 [730.84 sec]: learning rate : 0.000013 loss : 0.717093 +[16:22:35.342] iteration 67600 [817.45 sec]: learning rate : 0.000013 loss : 0.548958 +[16:24:01.858] iteration 67700 [903.97 sec]: learning rate : 0.000013 loss : 0.618346 +[16:25:28.472] iteration 67800 [990.60 sec]: learning rate : 0.000013 loss : 0.536738 +[16:26:55.100] iteration 67900 [1077.21 sec]: learning rate : 0.000013 loss : 0.750287 +[16:28:21.652] iteration 68000 [1163.76 sec]: learning rate : 0.000013 loss : 0.595791 +[16:29:48.286] iteration 68100 [1250.39 sec]: learning rate : 0.000013 loss : 0.456667 +[16:31:14.838] iteration 68200 [1336.95 sec]: learning rate : 0.000013 loss : 0.630840 +[16:32:41.487] iteration 68300 [1423.60 sec]: learning rate : 0.000013 loss : 0.580166 +[16:34:08.075] iteration 68400 [1510.18 sec]: learning rate : 0.000013 loss : 0.295896 +[16:35:34.658] iteration 68500 [1596.77 sec]: learning rate : 0.000013 loss : 0.559047 +[16:37:01.285] iteration 68600 [1683.39 sec]: learning rate : 0.000013 loss : 0.393001 +[16:38:27.876] iteration 68700 [1769.98 sec]: learning rate : 0.000013 loss : 0.486973 +[16:39:01.610] Epoch 32 Evaluation: +[16:39:51.033] average MSE: 0.05590313673019409 average PSNR: 28.175897637846493 average SSIM: 0.6031504387313992 +[16:40:44.072] iteration 68800 [52.98 sec]: learning rate : 0.000013 loss : 0.479305 +[16:42:10.732] iteration 68900 [139.64 sec]: learning rate : 0.000013 loss : 0.435416 +[16:43:37.345] iteration 69000 [226.25 sec]: learning rate : 0.000013 loss : 0.754315 +[16:45:03.881] iteration 69100 [312.79 sec]: learning rate : 0.000013 loss : 0.372061 +[16:46:30.478] iteration 69200 [399.38 sec]: learning rate : 0.000013 loss : 0.599238 +[16:47:56.999] iteration 69300 [485.90 sec]: learning rate : 0.000013 loss : 0.590950 +[16:49:23.530] iteration 69400 [572.43 sec]: learning rate : 0.000013 loss : 0.381688 +[16:50:50.102] iteration 69500 [659.01 sec]: learning rate : 0.000013 loss : 0.405773 +[16:52:16.650] iteration 69600 [745.56 sec]: learning rate : 0.000013 loss : 0.502511 +[16:53:43.230] iteration 69700 [832.13 sec]: learning rate : 0.000013 loss : 0.733952 +[16:55:09.771] iteration 69800 [918.68 sec]: learning rate : 0.000013 loss : 0.508141 +[16:56:36.361] iteration 69900 [1005.27 sec]: learning rate : 0.000013 loss : 0.709146 +[16:58:02.988] iteration 70000 [1091.89 sec]: learning rate : 0.000013 loss : 0.826527 +[16:59:29.506] iteration 70100 [1178.41 sec]: learning rate : 0.000013 loss : 0.327196 +[17:00:56.077] iteration 70200 [1264.98 sec]: learning rate : 0.000013 loss : 0.664322 +[17:02:22.645] iteration 70300 [1351.55 sec]: learning rate : 0.000013 loss : 0.435639 +[17:03:49.195] 
iteration 70400 [1438.10 sec]: learning rate : 0.000013 loss : 0.490736 +[17:05:15.788] iteration 70500 [1524.69 sec]: learning rate : 0.000013 loss : 0.625126 +[17:06:42.375] iteration 70600 [1611.28 sec]: learning rate : 0.000013 loss : 0.320439 +[17:08:08.972] iteration 70700 [1697.88 sec]: learning rate : 0.000013 loss : 1.002352 +[17:09:35.587] iteration 70800 [1784.49 sec]: learning rate : 0.000013 loss : 0.449198 +[17:09:54.645] Epoch 33 Evaluation: +[17:10:44.077] average MSE: 0.05576891824603081 average PSNR: 28.1938265213284 average SSIM: 0.6033244153506649 +[17:11:51.789] iteration 70900 [67.65 sec]: learning rate : 0.000013 loss : 0.413963 +[17:13:18.378] iteration 71000 [154.24 sec]: learning rate : 0.000013 loss : 0.738987 +[17:14:44.946] iteration 71100 [240.80 sec]: learning rate : 0.000013 loss : 0.564024 +[17:16:11.464] iteration 71200 [327.32 sec]: learning rate : 0.000013 loss : 0.315234 +[17:17:38.059] iteration 71300 [413.92 sec]: learning rate : 0.000013 loss : 0.839121 +[17:19:04.684] iteration 71400 [500.54 sec]: learning rate : 0.000013 loss : 0.473867 +[17:20:31.239] iteration 71500 [587.10 sec]: learning rate : 0.000013 loss : 0.486883 +[17:21:57.841] iteration 71600 [673.70 sec]: learning rate : 0.000013 loss : 0.650074 +[17:23:24.399] iteration 71700 [760.26 sec]: learning rate : 0.000013 loss : 0.574203 +[17:24:50.964] iteration 71800 [846.82 sec]: learning rate : 0.000013 loss : 0.844860 +[17:26:17.516] iteration 71900 [933.38 sec]: learning rate : 0.000013 loss : 0.715943 +[17:27:44.072] iteration 72000 [1019.93 sec]: learning rate : 0.000013 loss : 0.563014 +[17:29:10.692] iteration 72100 [1106.55 sec]: learning rate : 0.000013 loss : 0.595197 +[17:30:37.306] iteration 72200 [1193.17 sec]: learning rate : 0.000013 loss : 0.521074 +[17:32:03.849] iteration 72300 [1279.71 sec]: learning rate : 0.000013 loss : 0.527515 +[17:33:30.454] iteration 72400 [1366.31 sec]: learning rate : 0.000013 loss : 0.503992 +[17:34:57.046] iteration 72500 [1452.91 sec]: learning rate : 0.000013 loss : 0.608579 +[17:36:23.578] iteration 72600 [1539.44 sec]: learning rate : 0.000013 loss : 0.708220 +[17:37:50.219] iteration 72700 [1626.08 sec]: learning rate : 0.000013 loss : 0.515599 +[17:39:16.814] iteration 72800 [1712.67 sec]: learning rate : 0.000013 loss : 0.791620 +[17:40:43.356] iteration 72900 [1799.22 sec]: learning rate : 0.000013 loss : 0.375221 +[17:40:47.664] Epoch 34 Evaluation: +[17:41:38.725] average MSE: 0.0555657334625721 average PSNR: 28.201332156753494 average SSIM: 0.603303344696349 +[17:43:01.272] iteration 73000 [82.48 sec]: learning rate : 0.000013 loss : 0.574371 +[17:44:27.875] iteration 73100 [169.09 sec]: learning rate : 0.000013 loss : 0.517586 +[17:45:54.450] iteration 73200 [255.66 sec]: learning rate : 0.000013 loss : 0.700112 +[17:47:21.073] iteration 73300 [342.28 sec]: learning rate : 0.000013 loss : 0.372433 +[17:48:47.650] iteration 73400 [428.86 sec]: learning rate : 0.000013 loss : 0.514923 +[17:50:14.228] iteration 73500 [515.44 sec]: learning rate : 0.000013 loss : 0.582750 +[17:51:40.866] iteration 73600 [602.08 sec]: learning rate : 0.000013 loss : 0.292940 +[17:53:07.435] iteration 73700 [688.65 sec]: learning rate : 0.000013 loss : 0.495514 +[17:54:34.095] iteration 73800 [775.31 sec]: learning rate : 0.000013 loss : 0.782746 +[17:56:00.664] iteration 73900 [861.88 sec]: learning rate : 0.000013 loss : 0.733333 +[17:57:27.273] iteration 74000 [948.48 sec]: learning rate : 0.000013 loss : 0.618489 +[17:58:53.861] iteration 74100 
[1035.08 sec]: learning rate : 0.000013 loss : 0.439265 +[18:00:20.416] iteration 74200 [1121.63 sec]: learning rate : 0.000013 loss : 0.436792 +[18:01:46.988] iteration 74300 [1208.20 sec]: learning rate : 0.000013 loss : 0.673403 +[18:03:13.595] iteration 74400 [1294.81 sec]: learning rate : 0.000013 loss : 0.523329 +[18:04:40.158] iteration 74500 [1381.37 sec]: learning rate : 0.000013 loss : 0.594553 +[18:06:06.758] iteration 74600 [1467.97 sec]: learning rate : 0.000013 loss : 0.568821 +[18:07:33.316] iteration 74700 [1554.53 sec]: learning rate : 0.000013 loss : 0.586636 +[18:08:59.968] iteration 74800 [1641.18 sec]: learning rate : 0.000013 loss : 0.546479 +[18:10:26.595] iteration 74900 [1727.81 sec]: learning rate : 0.000013 loss : 0.883853 +[18:11:42.757] Epoch 35 Evaluation: +[18:12:32.156] average MSE: 0.05572645738720894 average PSNR: 28.191200068040104 average SSIM: 0.6038442574083834 +[18:12:42.823] iteration 75000 [10.60 sec]: learning rate : 0.000013 loss : 1.219231 +[18:14:09.480] iteration 75100 [97.26 sec]: learning rate : 0.000013 loss : 0.686463 +[18:15:36.084] iteration 75200 [183.86 sec]: learning rate : 0.000013 loss : 0.376049 +[18:17:02.671] iteration 75300 [270.45 sec]: learning rate : 0.000013 loss : 0.585044 +[18:18:29.295] iteration 75400 [357.08 sec]: learning rate : 0.000013 loss : 0.523208 +[18:19:55.973] iteration 75500 [443.75 sec]: learning rate : 0.000013 loss : 0.735513 +[18:21:22.576] iteration 75600 [530.36 sec]: learning rate : 0.000013 loss : 0.703590 +[18:22:49.223] iteration 75700 [617.00 sec]: learning rate : 0.000013 loss : 0.732946 +[18:24:15.845] iteration 75800 [703.65 sec]: learning rate : 0.000013 loss : 0.593827 +[18:25:42.564] iteration 75900 [790.34 sec]: learning rate : 0.000013 loss : 0.601681 +[18:27:09.196] iteration 76000 [876.98 sec]: learning rate : 0.000013 loss : 0.571764 +[18:28:35.794] iteration 76100 [963.58 sec]: learning rate : 0.000013 loss : 0.402085 +[18:30:02.479] iteration 76200 [1050.26 sec]: learning rate : 0.000013 loss : 0.590664 +[18:31:29.092] iteration 76300 [1136.87 sec]: learning rate : 0.000013 loss : 0.531561 +[18:32:55.774] iteration 76400 [1223.55 sec]: learning rate : 0.000013 loss : 0.448103 +[18:34:22.447] iteration 76500 [1310.23 sec]: learning rate : 0.000013 loss : 0.682670 +[18:35:49.041] iteration 76600 [1396.82 sec]: learning rate : 0.000013 loss : 0.652864 +[18:37:15.666] iteration 76700 [1483.45 sec]: learning rate : 0.000013 loss : 0.745813 +[18:38:42.268] iteration 76800 [1570.05 sec]: learning rate : 0.000013 loss : 0.692996 +[18:40:08.881] iteration 76900 [1656.66 sec]: learning rate : 0.000013 loss : 0.624274 +[18:41:35.543] iteration 77000 [1743.32 sec]: learning rate : 0.000013 loss : 0.380772 +[18:42:37.011] Epoch 36 Evaluation: +[18:43:26.946] average MSE: 0.055519770830869675 average PSNR: 28.207398111902364 average SSIM: 0.6035472569930874 +[18:43:52.309] iteration 77100 [25.30 sec]: learning rate : 0.000013 loss : 0.473987 +[18:45:18.988] iteration 77200 [111.98 sec]: learning rate : 0.000013 loss : 0.391022 +[18:46:45.619] iteration 77300 [198.61 sec]: learning rate : 0.000013 loss : 0.664152 +[18:48:12.205] iteration 77400 [285.20 sec]: learning rate : 0.000013 loss : 0.627378 +[18:49:38.848] iteration 77500 [371.84 sec]: learning rate : 0.000013 loss : 0.810408 +[18:51:05.511] iteration 77600 [458.50 sec]: learning rate : 0.000013 loss : 0.788255 +[18:52:32.110] iteration 77700 [545.10 sec]: learning rate : 0.000013 loss : 0.894113 +[18:53:58.750] iteration 77800 [631.74 sec]: 
learning rate : 0.000013 loss : 0.777961 +[18:55:25.367] iteration 77900 [718.36 sec]: learning rate : 0.000013 loss : 0.682631 +[18:56:51.962] iteration 78000 [804.95 sec]: learning rate : 0.000013 loss : 0.653850 +[18:58:18.593] iteration 78100 [891.58 sec]: learning rate : 0.000013 loss : 0.468537 +[18:59:45.176] iteration 78200 [978.17 sec]: learning rate : 0.000013 loss : 0.476061 +[19:01:11.837] iteration 78300 [1064.83 sec]: learning rate : 0.000013 loss : 0.620662 +[19:02:38.434] iteration 78400 [1151.43 sec]: learning rate : 0.000013 loss : 0.800640 +[19:04:05.033] iteration 78500 [1238.02 sec]: learning rate : 0.000013 loss : 0.443015 +[19:05:31.662] iteration 78600 [1324.65 sec]: learning rate : 0.000013 loss : 0.701694 +[19:06:58.264] iteration 78700 [1411.27 sec]: learning rate : 0.000013 loss : 0.715121 +[19:08:24.830] iteration 78800 [1497.82 sec]: learning rate : 0.000013 loss : 0.561452 +[19:09:51.530] iteration 78900 [1584.52 sec]: learning rate : 0.000013 loss : 0.603398 +[19:11:18.179] iteration 79000 [1671.17 sec]: learning rate : 0.000013 loss : 0.491683 +[19:12:44.767] iteration 79100 [1757.76 sec]: learning rate : 0.000013 loss : 0.461602 +[19:13:31.566] Epoch 37 Evaluation: +[19:14:21.232] average MSE: 0.055369798094034195 average PSNR: 28.224171584984493 average SSIM: 0.6039663581627687 +[19:15:01.294] iteration 79200 [40.00 sec]: learning rate : 0.000013 loss : 0.534041 +[19:16:27.827] iteration 79300 [126.53 sec]: learning rate : 0.000013 loss : 0.501365 +[19:17:54.541] iteration 79400 [213.25 sec]: learning rate : 0.000013 loss : 0.384668 +[19:19:21.159] iteration 79500 [299.86 sec]: learning rate : 0.000013 loss : 0.712430 +[19:20:47.746] iteration 79600 [386.45 sec]: learning rate : 0.000013 loss : 0.577006 +[19:22:14.394] iteration 79700 [473.10 sec]: learning rate : 0.000013 loss : 0.574304 +[19:23:41.056] iteration 79800 [559.76 sec]: learning rate : 0.000013 loss : 0.307254 +[19:25:07.636] iteration 79900 [646.34 sec]: learning rate : 0.000013 loss : 0.500365 +[19:26:34.325] iteration 80000 [733.03 sec]: learning rate : 0.000003 loss : 0.657453 +[19:26:34.485] save model to model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/iter_80000.pth +[19:28:01.124] iteration 80100 [819.83 sec]: learning rate : 0.000006 loss : 0.461392 +[19:29:27.678] iteration 80200 [906.38 sec]: learning rate : 0.000006 loss : 0.339288 +[19:30:54.306] iteration 80300 [993.01 sec]: learning rate : 0.000006 loss : 0.528792 +[19:32:20.871] iteration 80400 [1079.58 sec]: learning rate : 0.000006 loss : 0.682093 +[19:33:47.510] iteration 80500 [1166.22 sec]: learning rate : 0.000006 loss : 0.555212 +[19:35:14.132] iteration 80600 [1252.84 sec]: learning rate : 0.000006 loss : 0.680776 +[19:36:40.692] iteration 80700 [1339.40 sec]: learning rate : 0.000006 loss : 0.568171 +[19:38:07.318] iteration 80800 [1426.02 sec]: learning rate : 0.000006 loss : 0.446498 +[19:39:33.979] iteration 80900 [1512.68 sec]: learning rate : 0.000006 loss : 0.608394 +[19:41:00.544] iteration 81000 [1599.25 sec]: learning rate : 0.000006 loss : 0.486342 +[19:42:27.124] iteration 81100 [1685.83 sec]: learning rate : 0.000006 loss : 0.716771 +[19:43:53.743] iteration 81200 [1772.51 sec]: learning rate : 0.000006 loss : 0.799922 +[19:44:25.735] Epoch 38 Evaluation: +[19:45:15.226] average MSE: 0.055412787944078445 average PSNR: 28.218379423758353 average SSIM: 0.6038843112423745 +[19:46:09.993] iteration 81300 [54.70 sec]: learning rate : 0.000006 loss : 0.590192 +[19:47:36.650] iteration 81400 [141.36 sec]: 
learning rate : 0.000006 loss : 0.433876 +[19:49:03.202] iteration 81500 [227.91 sec]: learning rate : 0.000006 loss : 0.844570 +[19:50:29.824] iteration 81600 [314.54 sec]: learning rate : 0.000006 loss : 0.309256 +[19:51:56.414] iteration 81700 [401.13 sec]: learning rate : 0.000006 loss : 0.409680 +[19:53:22.986] iteration 81800 [487.70 sec]: learning rate : 0.000006 loss : 0.556242 +[19:54:49.640] iteration 81900 [574.35 sec]: learning rate : 0.000006 loss : 0.416310 +[19:56:16.257] iteration 82000 [660.97 sec]: learning rate : 0.000006 loss : 0.705743 +[19:57:42.856] iteration 82100 [747.57 sec]: learning rate : 0.000006 loss : 0.449665 +[19:59:09.521] iteration 82200 [834.23 sec]: learning rate : 0.000006 loss : 0.730912 +[20:00:36.099] iteration 82300 [920.81 sec]: learning rate : 0.000006 loss : 0.598102 +[20:02:02.751] iteration 82400 [1007.46 sec]: learning rate : 0.000006 loss : 0.547808 +[20:03:29.417] iteration 82500 [1094.13 sec]: learning rate : 0.000006 loss : 0.490513 +[20:04:56.008] iteration 82600 [1180.72 sec]: learning rate : 0.000006 loss : 0.402567 +[20:06:22.638] iteration 82700 [1267.35 sec]: learning rate : 0.000006 loss : 0.509759 +[20:07:49.210] iteration 82800 [1353.92 sec]: learning rate : 0.000006 loss : 0.613957 +[20:09:15.848] iteration 82900 [1440.56 sec]: learning rate : 0.000006 loss : 0.517702 +[20:10:42.481] iteration 83000 [1527.19 sec]: learning rate : 0.000006 loss : 0.647116 +[20:12:09.080] iteration 83100 [1613.79 sec]: learning rate : 0.000006 loss : 0.421784 +[20:13:35.735] iteration 83200 [1700.45 sec]: learning rate : 0.000006 loss : 0.791377 +[20:15:02.330] iteration 83300 [1787.04 sec]: learning rate : 0.000006 loss : 0.749417 +[20:15:19.661] Epoch 39 Evaluation: +[20:16:09.026] average MSE: 0.055352937430143356 average PSNR: 28.22172652004159 average SSIM: 0.6039448571730875 +[20:17:18.514] iteration 83400 [69.43 sec]: learning rate : 0.000006 loss : 0.312342 +[20:18:45.166] iteration 83500 [156.08 sec]: learning rate : 0.000006 loss : 0.413705 +[20:20:11.754] iteration 83600 [242.67 sec]: learning rate : 0.000006 loss : 0.655057 +[20:21:38.406] iteration 83700 [329.32 sec]: learning rate : 0.000006 loss : 0.723773 +[20:23:05.058] iteration 83800 [415.97 sec]: learning rate : 0.000006 loss : 0.496424 +[20:24:31.654] iteration 83900 [502.56 sec]: learning rate : 0.000006 loss : 0.491815 +[20:25:58.306] iteration 84000 [589.22 sec]: learning rate : 0.000006 loss : 0.502291 +[20:27:24.940] iteration 84100 [675.85 sec]: learning rate : 0.000006 loss : 0.552914 +[20:28:51.561] iteration 84200 [762.47 sec]: learning rate : 0.000006 loss : 0.666022 +[20:30:18.252] iteration 84300 [849.16 sec]: learning rate : 0.000006 loss : 0.677633 +[20:31:44.888] iteration 84400 [935.80 sec]: learning rate : 0.000006 loss : 0.506620 +[20:33:11.486] iteration 84500 [1022.40 sec]: learning rate : 0.000006 loss : 0.340875 +[20:34:38.166] iteration 84600 [1109.08 sec]: learning rate : 0.000006 loss : 0.413161 +[20:36:04.747] iteration 84700 [1195.66 sec]: learning rate : 0.000006 loss : 0.470065 +[20:37:31.434] iteration 84800 [1282.35 sec]: learning rate : 0.000006 loss : 0.561438 +[20:38:58.077] iteration 84900 [1368.99 sec]: learning rate : 0.000006 loss : 0.536467 +[20:40:24.676] iteration 85000 [1455.59 sec]: learning rate : 0.000006 loss : 1.339119 +[20:41:51.290] iteration 85100 [1542.20 sec]: learning rate : 0.000006 loss : 0.578255 +[20:43:17.931] iteration 85200 [1628.84 sec]: learning rate : 0.000006 loss : 0.720506 +[20:44:44.520] iteration 85300 
[1715.43 sec]: learning rate : 0.000006 loss : 0.579647 +[20:46:11.166] iteration 85400 [1802.08 sec]: learning rate : 0.000006 loss : 0.637011 +[20:46:13.744] Epoch 40 Evaluation: +[20:47:04.168] average MSE: 0.055449556559324265 average PSNR: 28.21721023838403 average SSIM: 0.6042227731846943 +[20:48:28.351] iteration 85500 [84.12 sec]: learning rate : 0.000006 loss : 0.648560 +[20:49:55.034] iteration 85600 [170.80 sec]: learning rate : 0.000006 loss : 0.797992 +[20:51:21.653] iteration 85700 [257.42 sec]: learning rate : 0.000006 loss : 0.418578 +[20:52:48.240] iteration 85800 [344.01 sec]: learning rate : 0.000006 loss : 0.832134 +[20:54:14.913] iteration 85900 [430.68 sec]: learning rate : 0.000006 loss : 0.240289 +[20:55:41.507] iteration 86000 [517.28 sec]: learning rate : 0.000006 loss : 0.339853 +[20:57:08.125] iteration 86100 [603.90 sec]: learning rate : 0.000006 loss : 0.386072 +[20:58:34.755] iteration 86200 [690.53 sec]: learning rate : 0.000006 loss : 0.650129 +[21:00:01.347] iteration 86300 [777.12 sec]: learning rate : 0.000006 loss : 0.556049 +[21:01:27.973] iteration 86400 [863.74 sec]: learning rate : 0.000006 loss : 0.343548 +[21:02:54.558] iteration 86500 [950.33 sec]: learning rate : 0.000006 loss : 0.605613 +[21:04:21.220] iteration 86600 [1036.99 sec]: learning rate : 0.000006 loss : 0.924431 +[21:05:47.840] iteration 86700 [1123.61 sec]: learning rate : 0.000006 loss : 0.430074 +[21:07:14.435] iteration 86800 [1210.20 sec]: learning rate : 0.000006 loss : 1.148954 +[21:08:41.096] iteration 86900 [1296.87 sec]: learning rate : 0.000006 loss : 0.548650 +[21:10:07.692] iteration 87000 [1383.46 sec]: learning rate : 0.000006 loss : 0.625255 +[21:11:34.324] iteration 87100 [1470.10 sec]: learning rate : 0.000006 loss : 0.460188 +[21:13:00.969] iteration 87200 [1556.74 sec]: learning rate : 0.000006 loss : 0.609283 +[21:14:27.541] iteration 87300 [1643.31 sec]: learning rate : 0.000006 loss : 0.581203 +[21:15:54.154] iteration 87400 [1729.92 sec]: learning rate : 0.000006 loss : 1.056414 +[21:17:08.581] Epoch 41 Evaluation: +[21:18:00.109] average MSE: 0.055575281381607056 average PSNR: 28.20728694492128 average SSIM: 0.6042138348105289 +[21:18:12.638] iteration 87500 [12.47 sec]: learning rate : 0.000006 loss : 0.810903 +[21:19:39.189] iteration 87600 [99.02 sec]: learning rate : 0.000006 loss : 0.447068 +[21:21:05.829] iteration 87700 [185.66 sec]: learning rate : 0.000006 loss : 0.491932 +[21:22:32.429] iteration 87800 [272.26 sec]: learning rate : 0.000006 loss : 0.598496 +[21:23:59.062] iteration 87900 [358.91 sec]: learning rate : 0.000006 loss : 0.532901 +[21:25:25.740] iteration 88000 [445.57 sec]: learning rate : 0.000006 loss : 0.544204 +[21:26:52.367] iteration 88100 [532.19 sec]: learning rate : 0.000006 loss : 0.474634 +[21:28:19.036] iteration 88200 [618.86 sec]: learning rate : 0.000006 loss : 0.608914 +[21:29:45.628] iteration 88300 [705.46 sec]: learning rate : 0.000006 loss : 0.455449 +[21:31:12.292] iteration 88400 [792.12 sec]: learning rate : 0.000006 loss : 0.587916 +[21:32:38.927] iteration 88500 [878.76 sec]: learning rate : 0.000006 loss : 0.502756 +[21:34:05.516] iteration 88600 [965.34 sec]: learning rate : 0.000006 loss : 0.728004 +[21:35:32.170] iteration 88700 [1052.00 sec]: learning rate : 0.000006 loss : 0.650533 +[21:36:58.841] iteration 88800 [1138.67 sec]: learning rate : 0.000006 loss : 0.646238 +[21:38:25.439] iteration 88900 [1225.27 sec]: learning rate : 0.000006 loss : 0.626404 +[21:39:52.091] iteration 89000 [1311.92 sec]: 
learning rate : 0.000006 loss : 0.607869 +[21:41:18.759] iteration 89100 [1398.59 sec]: learning rate : 0.000006 loss : 0.450869 +[21:42:45.368] iteration 89200 [1485.20 sec]: learning rate : 0.000006 loss : 0.701277 +[21:44:12.068] iteration 89300 [1571.89 sec]: learning rate : 0.000006 loss : 0.559867 +[21:45:38.735] iteration 89400 [1658.56 sec]: learning rate : 0.000006 loss : 0.532936 +[21:47:05.342] iteration 89500 [1745.17 sec]: learning rate : 0.000006 loss : 0.202993 +[21:48:05.134] Epoch 42 Evaluation: +[21:48:55.060] average MSE: 0.05548651143908501 average PSNR: 28.21752449090761 average SSIM: 0.6041013777492691 +[21:49:22.140] iteration 89600 [27.02 sec]: learning rate : 0.000006 loss : 0.386244 +[21:50:48.797] iteration 89700 [113.67 sec]: learning rate : 0.000006 loss : 0.645813 +[21:52:15.391] iteration 89800 [200.27 sec]: learning rate : 0.000006 loss : 0.734795 +[21:53:42.053] iteration 89900 [286.93 sec]: learning rate : 0.000006 loss : 0.308833 +[21:55:08.725] iteration 90000 [373.60 sec]: learning rate : 0.000006 loss : 0.661309 +[21:56:35.339] iteration 90100 [460.22 sec]: learning rate : 0.000006 loss : 0.483884 +[21:58:02.049] iteration 90200 [546.93 sec]: learning rate : 0.000006 loss : 0.653549 +[21:59:28.742] iteration 90300 [633.62 sec]: learning rate : 0.000006 loss : 0.622092 +[22:00:55.377] iteration 90400 [720.26 sec]: learning rate : 0.000006 loss : 0.428379 +[22:02:22.044] iteration 90500 [806.92 sec]: learning rate : 0.000006 loss : 0.537603 +[22:03:48.708] iteration 90600 [893.59 sec]: learning rate : 0.000006 loss : 0.661764 +[22:05:15.342] iteration 90700 [980.22 sec]: learning rate : 0.000006 loss : 0.418077 +[22:06:42.013] iteration 90800 [1066.89 sec]: learning rate : 0.000006 loss : 0.633769 +[22:08:08.715] iteration 90900 [1153.59 sec]: learning rate : 0.000006 loss : 0.589279 +[22:09:35.333] iteration 91000 [1240.21 sec]: learning rate : 0.000006 loss : 0.441004 +[22:11:01.993] iteration 91100 [1326.87 sec]: learning rate : 0.000006 loss : 0.648901 +[22:12:28.725] iteration 91200 [1413.60 sec]: learning rate : 0.000006 loss : 0.887169 +[22:13:55.334] iteration 91300 [1500.21 sec]: learning rate : 0.000006 loss : 0.432734 +[22:15:22.020] iteration 91400 [1586.90 sec]: learning rate : 0.000006 loss : 0.442684 +[22:16:48.634] iteration 91500 [1673.51 sec]: learning rate : 0.000006 loss : 0.648870 +[22:18:15.268] iteration 91600 [1760.15 sec]: learning rate : 0.000006 loss : 0.619774 +[22:19:00.286] Epoch 43 Evaluation: +[22:19:51.861] average MSE: 0.055311840027570724 average PSNR: 28.231590197587654 average SSIM: 0.6044201213437892 +[22:20:33.789] iteration 91700 [41.86 sec]: learning rate : 0.000006 loss : 0.416636 +[22:22:00.361] iteration 91800 [128.43 sec]: learning rate : 0.000006 loss : 0.580758 +[22:23:27.002] iteration 91900 [215.08 sec]: learning rate : 0.000006 loss : 0.522611 +[22:24:53.594] iteration 92000 [301.67 sec]: learning rate : 0.000006 loss : 0.492248 +[22:26:20.301] iteration 92100 [388.37 sec]: learning rate : 0.000006 loss : 0.469964 +[22:27:46.998] iteration 92200 [475.07 sec]: learning rate : 0.000006 loss : 0.601655 +[22:29:13.611] iteration 92300 [561.69 sec]: learning rate : 0.000006 loss : 0.382403 +[22:30:40.263] iteration 92400 [648.34 sec]: learning rate : 0.000006 loss : 0.593479 +[22:32:06.871] iteration 92500 [734.94 sec]: learning rate : 0.000006 loss : 0.561913 +[22:33:33.469] iteration 92600 [821.54 sec]: learning rate : 0.000006 loss : 0.427269 +[22:35:00.122] iteration 92700 [908.20 sec]: learning rate : 
0.000006 loss : 0.531767 +[22:36:26.704] iteration 92800 [994.78 sec]: learning rate : 0.000006 loss : 0.724908 +[22:37:53.366] iteration 92900 [1081.44 sec]: learning rate : 0.000006 loss : 0.682651 +[22:39:20.065] iteration 93000 [1168.14 sec]: learning rate : 0.000006 loss : 0.478246 +[22:40:46.690] iteration 93100 [1254.76 sec]: learning rate : 0.000006 loss : 0.886133 +[22:42:13.334] iteration 93200 [1341.41 sec]: learning rate : 0.000006 loss : 0.789745 +[22:43:39.920] iteration 93300 [1427.99 sec]: learning rate : 0.000006 loss : 0.453592 +[22:45:06.533] iteration 93400 [1514.61 sec]: learning rate : 0.000006 loss : 0.580272 +[22:46:33.199] iteration 93500 [1601.27 sec]: learning rate : 0.000006 loss : 0.491382 +[22:47:59.796] iteration 93600 [1687.87 sec]: learning rate : 0.000006 loss : 0.547160 +[22:49:26.441] iteration 93700 [1774.51 sec]: learning rate : 0.000006 loss : 0.576578 +[22:49:56.726] Epoch 44 Evaluation: +[22:50:48.115] average MSE: 0.05530009791254997 average PSNR: 28.226122104190633 average SSIM: 0.6036984988254975 +[22:51:44.732] iteration 93800 [56.55 sec]: learning rate : 0.000006 loss : 0.604688 +[22:53:11.291] iteration 93900 [143.11 sec]: learning rate : 0.000006 loss : 0.443973 +[22:54:37.888] iteration 94000 [229.71 sec]: learning rate : 0.000006 loss : 0.275209 +[22:56:04.461] iteration 94100 [316.28 sec]: learning rate : 0.000006 loss : 0.621099 +[22:57:31.140] iteration 94200 [402.96 sec]: learning rate : 0.000006 loss : 0.527014 +[22:58:57.811] iteration 94300 [489.63 sec]: learning rate : 0.000006 loss : 0.616924 +[23:00:24.401] iteration 94400 [576.22 sec]: learning rate : 0.000006 loss : 0.927693 +[23:01:51.055] iteration 94500 [662.88 sec]: learning rate : 0.000006 loss : 0.447667 +[23:03:17.644] iteration 94600 [749.47 sec]: learning rate : 0.000006 loss : 0.870831 +[23:04:44.280] iteration 94700 [836.10 sec]: learning rate : 0.000006 loss : 0.532043 +[23:06:10.923] iteration 94800 [922.75 sec]: learning rate : 0.000006 loss : 0.485287 +[23:07:37.511] iteration 94900 [1009.33 sec]: learning rate : 0.000006 loss : 0.637572 +[23:09:04.200] iteration 95000 [1096.02 sec]: learning rate : 0.000006 loss : 0.664566 +[23:10:30.789] iteration 95100 [1182.61 sec]: learning rate : 0.000006 loss : 0.462268 +[23:11:57.472] iteration 95200 [1269.29 sec]: learning rate : 0.000006 loss : 0.521645 +[23:13:24.163] iteration 95300 [1355.98 sec]: learning rate : 0.000006 loss : 0.425991 +[23:14:50.763] iteration 95400 [1442.61 sec]: learning rate : 0.000006 loss : 0.501234 +[23:16:17.419] iteration 95500 [1529.24 sec]: learning rate : 0.000006 loss : 0.525054 +[23:17:44.103] iteration 95600 [1615.92 sec]: learning rate : 0.000006 loss : 0.607752 +[23:19:10.693] iteration 95700 [1702.51 sec]: learning rate : 0.000006 loss : 0.838152 +[23:20:37.361] iteration 95800 [1789.18 sec]: learning rate : 0.000006 loss : 0.745968 +[23:20:52.924] Epoch 45 Evaluation: +[23:21:42.343] average MSE: 0.05538278445601463 average PSNR: 28.22224848323496 average SSIM: 0.6042506044024955 +[23:22:53.672] iteration 95900 [71.27 sec]: learning rate : 0.000006 loss : 0.356444 +[23:24:20.247] iteration 96000 [157.84 sec]: learning rate : 0.000006 loss : 0.681741 +[23:25:46.916] iteration 96100 [244.51 sec]: learning rate : 0.000006 loss : 0.597013 +[23:27:13.519] iteration 96200 [331.11 sec]: learning rate : 0.000006 loss : 0.529531 +[23:28:40.154] iteration 96300 [417.75 sec]: learning rate : 0.000006 loss : 0.599271 +[23:30:06.834] iteration 96400 [504.43 sec]: learning rate : 0.000006 loss : 
0.698090 +[23:31:33.442] iteration 96500 [591.04 sec]: learning rate : 0.000006 loss : 0.472729 +[23:33:00.110] iteration 96600 [677.70 sec]: learning rate : 0.000006 loss : 0.630001 +[23:34:26.759] iteration 96700 [764.35 sec]: learning rate : 0.000006 loss : 0.471789 +[23:35:53.369] iteration 96800 [850.96 sec]: learning rate : 0.000006 loss : 0.440141 +[23:37:20.026] iteration 96900 [937.62 sec]: learning rate : 0.000006 loss : 0.627044 +[23:38:46.681] iteration 97000 [1024.28 sec]: learning rate : 0.000006 loss : 0.801878 +[23:40:13.316] iteration 97100 [1110.91 sec]: learning rate : 0.000006 loss : 0.675188 +[23:41:40.004] iteration 97200 [1197.62 sec]: learning rate : 0.000006 loss : 0.622306 +[23:43:06.667] iteration 97300 [1284.32 sec]: learning rate : 0.000006 loss : 0.748108 +[23:44:33.278] iteration 97400 [1370.87 sec]: learning rate : 0.000006 loss : 0.659703 +[23:45:59.952] iteration 97500 [1457.55 sec]: learning rate : 0.000006 loss : 0.624419 +[23:47:26.545] iteration 97600 [1544.14 sec]: learning rate : 0.000006 loss : 0.712648 +[23:48:53.219] iteration 97700 [1630.81 sec]: learning rate : 0.000006 loss : 0.901880 +[23:50:19.863] iteration 97800 [1717.46 sec]: learning rate : 0.000006 loss : 0.567070 +[23:51:46.450] iteration 97900 [1804.04 sec]: learning rate : 0.000006 loss : 0.793731 +[23:51:47.286] Epoch 46 Evaluation: +[23:52:38.090] average MSE: 0.05527488514780998 average PSNR: 28.2305905458651 average SSIM: 0.6042656034191775 +[23:54:04.175] iteration 98000 [86.02 sec]: learning rate : 0.000006 loss : 0.503267 +[23:55:30.807] iteration 98100 [172.66 sec]: learning rate : 0.000006 loss : 0.676058 +[23:56:57.380] iteration 98200 [259.23 sec]: learning rate : 0.000006 loss : 0.474281 +[23:58:24.013] iteration 98300 [345.86 sec]: learning rate : 0.000006 loss : 0.488607 +[23:59:50.658] iteration 98400 [432.50 sec]: learning rate : 0.000006 loss : 0.349760 +[00:01:17.261] iteration 98500 [519.11 sec]: learning rate : 0.000006 loss : 0.443909 +[00:02:43.901] iteration 98600 [605.75 sec]: learning rate : 0.000006 loss : 0.466257 +[00:04:10.583] iteration 98700 [692.43 sec]: learning rate : 0.000006 loss : 0.607537 +[00:05:37.187] iteration 98800 [779.03 sec]: learning rate : 0.000006 loss : 0.535600 +[00:07:03.861] iteration 98900 [865.71 sec]: learning rate : 0.000006 loss : 0.489714 +[00:08:30.441] iteration 99000 [952.29 sec]: learning rate : 0.000006 loss : 0.955745 +[00:09:57.097] iteration 99100 [1038.94 sec]: learning rate : 0.000006 loss : 0.581975 +[00:11:23.742] iteration 99200 [1125.59 sec]: learning rate : 0.000006 loss : 0.493614 +[00:12:50.331] iteration 99300 [1212.18 sec]: learning rate : 0.000006 loss : 0.529232 +[00:14:16.988] iteration 99400 [1298.84 sec]: learning rate : 0.000006 loss : 0.553029 +[00:15:43.682] iteration 99500 [1385.53 sec]: learning rate : 0.000006 loss : 0.795860 +[00:17:10.276] iteration 99600 [1472.12 sec]: learning rate : 0.000006 loss : 0.467312 +[00:18:36.950] iteration 99700 [1558.80 sec]: learning rate : 0.000006 loss : 0.851902 +[00:20:03.614] iteration 99800 [1645.46 sec]: learning rate : 0.000006 loss : 0.543774 +[00:21:30.156] iteration 99900 [1732.00 sec]: learning rate : 0.000006 loss : 0.523340 +[00:22:42.843] Epoch 47 Evaluation: +[00:23:34.409] average MSE: 0.055309951305389404 average PSNR: 28.22680777496259 average SSIM: 0.6043843787088766 +[00:23:48.497] iteration 100000 [14.02 sec]: learning rate : 0.000002 loss : 0.495827 +[00:23:48.659] save model to 
model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/iter_100000.pth +[00:23:49.498] Epoch 48 Evaluation: +[00:24:40.669] average MSE: 0.05530574917793274 average PSNR: 28.22215155458237 average SSIM: 0.6037907832671158 +[00:24:40.943] save model to model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/iter_100000.pth diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/log/events.out.tfevents.1752648095.GCRSANDBOX133 b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/log/events.out.tfevents.1752648095.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..a56741bc632b2f44f6ef05da72d7e3c42120396a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time_no_distortion/log/events.out.tfevents.1752648095.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:923b164b741de9bab5fb49bf42d7ebd23c17726e10e0a996d0658445a8c4115b +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..2505b1b204dbe81df6e3c91fdac2ad13a9984f18 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f6eb99212d40acde82a98ce5574fbd87efb4f40eab619ed18baee22d0fbb2d3 +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f7dd4f84a78b1d6c73d917c6e8fb01f06677bbf --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/log.txt @@ -0,0 +1,1367 @@ +[20:34:02.636] Namespace(root_path='/home/v-qichen3/MRI_recon/data/m4raw', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='1', exp='FSMNet_m4raw_4x_lr5e-4', max_iterations=100000, batch_size=4, base_lr=0.0005, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=10, image_size=240, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[20:40:30.125] iteration 100 [48.96 sec]: learning rate : 0.000500 loss : 0.508609 +[20:41:17.736] iteration 200 [96.56 sec]: learning rate : 0.000500 loss : 0.462476 +[20:42:05.368] iteration 300 [144.19 sec]: learning rate : 0.000500 loss : 0.449851 +[20:42:53.313] iteration 400 [192.13 sec]: learning rate : 0.000500 loss : 0.492846 +[20:43:40.837] iteration 500 [239.66 sec]: learning rate : 0.000500 loss : 0.454343 +[20:44:17.055] Epoch 0 Evaluation: +[20:45:44.135] average MSE: 0.05162555729223008 average PSNR: 25.915480416408524 average SSIM: 0.701826723738366 +[20:45:55.767] iteration 600 [11.61 
sec]: learning rate : 0.000500 loss : 0.533740 +[20:46:43.430] iteration 700 [59.27 sec]: learning rate : 0.000500 loss : 0.340977 +[20:47:30.953] iteration 800 [106.79 sec]: learning rate : 0.000500 loss : 0.400986 +[20:48:18.910] iteration 900 [154.75 sec]: learning rate : 0.000500 loss : 0.568610 +[20:49:06.977] iteration 1000 [202.82 sec]: learning rate : 0.000500 loss : 0.547366 +[20:49:54.590] iteration 1100 [250.43 sec]: learning rate : 0.000500 loss : 0.529727 +[20:50:19.965] Epoch 1 Evaluation: +[20:51:47.566] average MSE: 0.04623548664848299 average PSNR: 26.38799940867835 average SSIM: 0.7087946124553622 +[20:52:10.857] iteration 1200 [23.27 sec]: learning rate : 0.000500 loss : 0.431724 +[20:52:58.542] iteration 1300 [70.95 sec]: learning rate : 0.000500 loss : 0.393340 +[20:53:46.232] iteration 1400 [118.64 sec]: learning rate : 0.000500 loss : 0.381093 +[20:54:34.373] iteration 1500 [166.78 sec]: learning rate : 0.000500 loss : 0.352314 +[20:55:22.520] iteration 1600 [214.93 sec]: learning rate : 0.000500 loss : 0.391696 +[20:56:10.254] iteration 1700 [262.67 sec]: learning rate : 0.000500 loss : 0.370115 +[20:56:23.605] Epoch 2 Evaluation: +[20:57:50.072] average MSE: 0.04703218156748821 average PSNR: 26.311590829370417 average SSIM: 0.7020418176372041 +[20:58:24.634] iteration 1800 [34.54 sec]: learning rate : 0.000500 loss : 0.581599 +[20:59:12.530] iteration 1900 [82.43 sec]: learning rate : 0.000500 loss : 0.381155 +[21:00:00.935] iteration 2000 [130.84 sec]: learning rate : 0.000500 loss : 0.439304 +[21:00:48.668] iteration 2100 [178.57 sec]: learning rate : 0.000500 loss : 0.341588 +[21:01:36.363] iteration 2200 [226.27 sec]: learning rate : 0.000500 loss : 0.396294 +[21:02:24.142] iteration 2300 [274.05 sec]: learning rate : 0.000500 loss : 0.294507 +[21:02:26.059] Epoch 3 Evaluation: +[21:03:57.719] average MSE: 0.04745225435464163 average PSNR: 26.273226324678518 average SSIM: 0.6992689769788804 +[21:04:43.854] iteration 2400 [46.11 sec]: learning rate : 0.000500 loss : 0.509557 +[21:05:32.214] iteration 2500 [94.47 sec]: learning rate : 0.000500 loss : 0.402084 +[21:06:19.972] iteration 2600 [142.23 sec]: learning rate : 0.000500 loss : 0.469918 +[21:07:07.579] iteration 2700 [189.84 sec]: learning rate : 0.000500 loss : 0.425484 +[21:07:55.269] iteration 2800 [237.53 sec]: learning rate : 0.000500 loss : 0.385821 +[21:08:33.960] Epoch 4 Evaluation: +[21:09:59.853] average MSE: 0.04796798689485013 average PSNR: 26.22094378082784 average SSIM: 0.6910122035633468 +[21:10:09.783] iteration 2900 [9.91 sec]: learning rate : 0.000500 loss : 0.337644 +[21:10:58.232] iteration 3000 [58.36 sec]: learning rate : 0.000500 loss : 0.315328 +[21:11:45.934] iteration 3100 [106.06 sec]: learning rate : 0.000500 loss : 0.418541 +[21:12:33.556] iteration 3200 [153.68 sec]: learning rate : 0.000500 loss : 0.412900 +[21:13:21.262] iteration 3300 [201.39 sec]: learning rate : 0.000500 loss : 0.465277 +[21:14:08.985] iteration 3400 [249.11 sec]: learning rate : 0.000500 loss : 0.434243 +[21:14:35.644] Epoch 5 Evaluation: +[21:16:02.452] average MSE: 0.044876700869925 average PSNR: 26.5169594465646 average SSIM: 0.7065428936752562 +[21:16:24.504] iteration 3500 [22.03 sec]: learning rate : 0.000500 loss : 0.491870 +[21:17:12.685] iteration 3600 [70.21 sec]: learning rate : 0.000500 loss : 0.355166 +[21:18:00.681] iteration 3700 [118.21 sec]: learning rate : 0.000500 loss : 0.417412 +[21:18:48.277] iteration 3800 [165.80 sec]: learning rate : 0.000500 loss : 0.443904 +[21:19:35.932] 
iteration 3900 [213.46 sec]: learning rate : 0.000500 loss : 0.433351 +[21:20:23.582] iteration 4000 [261.11 sec]: learning rate : 0.000500 loss : 0.467498 +[21:20:38.799] Epoch 6 Evaluation: +[21:22:09.919] average MSE: 0.05404674027356621 average PSNR: 25.695872941123067 average SSIM: 0.664894979330106 +[21:22:42.540] iteration 4100 [32.60 sec]: learning rate : 0.000500 loss : 0.378066 +[21:23:30.369] iteration 4200 [80.43 sec]: learning rate : 0.000500 loss : 0.274678 +[21:24:18.114] iteration 4300 [128.17 sec]: learning rate : 0.000500 loss : 0.396430 +[21:25:05.765] iteration 4400 [175.82 sec]: learning rate : 0.000500 loss : 0.324796 +[21:25:53.912] iteration 4500 [223.97 sec]: learning rate : 0.000500 loss : 0.379425 +[21:26:42.125] iteration 4600 [272.18 sec]: learning rate : 0.000500 loss : 0.439478 +[21:26:45.947] Epoch 7 Evaluation: +[21:28:17.426] average MSE: 0.059398788826984114 average PSNR: 25.287873431158147 average SSIM: 0.6357055720254421 +[21:29:01.692] iteration 4700 [44.24 sec]: learning rate : 0.000500 loss : 0.354961 +[21:29:49.528] iteration 4800 [92.08 sec]: learning rate : 0.000500 loss : 0.370665 +[21:30:37.166] iteration 4900 [139.72 sec]: learning rate : 0.000500 loss : 0.423517 +[21:31:24.936] iteration 5000 [187.49 sec]: learning rate : 0.000500 loss : 0.321320 +[21:32:12.730] iteration 5100 [235.28 sec]: learning rate : 0.000500 loss : 0.300584 +[21:32:52.731] Epoch 8 Evaluation: +[21:34:24.118] average MSE: 0.05788884292075888 average PSNR: 25.400282335870685 average SSIM: 0.6457146294282282 +[21:34:31.946] iteration 5200 [7.80 sec]: learning rate : 0.000500 loss : 0.317447 +[21:35:19.874] iteration 5300 [55.73 sec]: learning rate : 0.000500 loss : 0.311694 +[21:36:07.484] iteration 5400 [103.36 sec]: learning rate : 0.000500 loss : 0.507216 +[21:36:55.004] iteration 5500 [150.86 sec]: learning rate : 0.000500 loss : 0.372041 +[21:37:42.658] iteration 5600 [198.52 sec]: learning rate : 0.000500 loss : 0.333808 +[21:38:30.188] iteration 5700 [246.05 sec]: learning rate : 0.000500 loss : 0.364874 +[21:38:59.283] Epoch 9 Evaluation: +[21:40:25.707] average MSE: 0.06400191271840754 average PSNR: 24.955257133309534 average SSIM: 0.6160075936092058 +[21:40:44.930] iteration 5800 [19.20 sec]: learning rate : 0.000500 loss : 0.411179 +[21:41:32.646] iteration 5900 [66.91 sec]: learning rate : 0.000500 loss : 0.456154 +[21:42:20.183] iteration 6000 [114.45 sec]: learning rate : 0.000500 loss : 0.392437 +[21:43:07.805] iteration 6100 [162.07 sec]: learning rate : 0.000500 loss : 0.368416 +[21:43:55.896] iteration 6200 [210.16 sec]: learning rate : 0.000500 loss : 0.413295 +[21:44:43.724] iteration 6300 [257.99 sec]: learning rate : 0.000500 loss : 0.401703 +[21:45:00.793] Epoch 10 Evaluation: +[21:46:26.915] average MSE: 0.047411742343816544 average PSNR: 26.277344892333392 average SSIM: 0.6989279745444752 +[21:46:57.593] iteration 6400 [30.66 sec]: learning rate : 0.000500 loss : 0.427407 +[21:47:44.949] iteration 6500 [78.01 sec]: learning rate : 0.000500 loss : 0.417619 +[21:48:33.012] iteration 6600 [126.08 sec]: learning rate : 0.000500 loss : 0.455869 +[21:49:20.559] iteration 6700 [173.62 sec]: learning rate : 0.000500 loss : 0.372738 +[21:50:08.373] iteration 6800 [221.44 sec]: learning rate : 0.000500 loss : 0.471824 +[21:50:55.842] iteration 6900 [268.90 sec]: learning rate : 0.000500 loss : 0.355869 +[21:51:01.548] Epoch 11 Evaluation: +[21:52:27.458] average MSE: 0.04747274486127635 average PSNR: 26.27383987041925 average SSIM: 0.6925152322397232 
+[21:53:10.075] iteration 7000 [42.59 sec]: learning rate : 0.000500 loss : 0.419182 +[21:53:57.500] iteration 7100 [90.02 sec]: learning rate : 0.000500 loss : 0.358579 +[21:54:45.026] iteration 7200 [137.55 sec]: learning rate : 0.000500 loss : 0.416770 +[21:55:33.286] iteration 7300 [185.81 sec]: learning rate : 0.000500 loss : 0.318997 +[21:56:20.686] iteration 7400 [233.20 sec]: learning rate : 0.000500 loss : 0.475418 +[21:57:02.484] Epoch 12 Evaluation: +[21:58:28.808] average MSE: 0.045695144438410704 average PSNR: 26.44636301652352 average SSIM: 0.7186597416810642 +[21:58:34.701] iteration 7500 [5.87 sec]: learning rate : 0.000500 loss : 0.390953 +[21:59:22.241] iteration 7600 [53.41 sec]: learning rate : 0.000500 loss : 0.270059 +[22:00:09.576] iteration 7700 [100.75 sec]: learning rate : 0.000500 loss : 0.258201 +[22:00:57.041] iteration 7800 [148.21 sec]: learning rate : 0.000500 loss : 0.298114 +[22:01:44.989] iteration 7900 [196.16 sec]: learning rate : 0.000500 loss : 0.478169 +[22:02:32.878] iteration 8000 [244.05 sec]: learning rate : 0.000500 loss : 0.400374 +[22:03:03.657] Epoch 13 Evaluation: +[22:04:30.854] average MSE: 0.05114772097288918 average PSNR: 25.94171865745492 average SSIM: 0.6755467246334047 +[22:04:48.137] iteration 8100 [17.26 sec]: learning rate : 0.000500 loss : 0.372578 +[22:05:35.676] iteration 8200 [64.80 sec]: learning rate : 0.000500 loss : 0.376690 +[22:06:23.200] iteration 8300 [112.32 sec]: learning rate : 0.000500 loss : 0.343565 +[22:07:11.447] iteration 8400 [160.57 sec]: learning rate : 0.000500 loss : 0.314614 +[22:07:59.037] iteration 8500 [208.16 sec]: learning rate : 0.000500 loss : 0.375012 +[22:08:46.702] iteration 8600 [255.82 sec]: learning rate : 0.000500 loss : 0.300907 +[22:09:05.741] Epoch 14 Evaluation: +[22:10:38.581] average MSE: 0.044557623296700447 average PSNR: 26.57117390702633 average SSIM: 0.7215790361745643 +[22:11:07.934] iteration 8700 [29.33 sec]: learning rate : 0.000500 loss : 0.355831 +[22:11:55.298] iteration 8800 [76.70 sec]: learning rate : 0.000500 loss : 0.370198 +[22:12:42.879] iteration 8900 [124.28 sec]: learning rate : 0.000500 loss : 0.356648 +[22:13:30.371] iteration 9000 [171.77 sec]: learning rate : 0.000500 loss : 0.415546 +[22:14:17.806] iteration 9100 [219.20 sec]: learning rate : 0.000500 loss : 0.367387 +[22:15:05.296] iteration 9200 [266.69 sec]: learning rate : 0.000500 loss : 0.350878 +[22:15:12.888] Epoch 15 Evaluation: +[22:16:43.488] average MSE: 0.04640158534856622 average PSNR: 26.36603179959617 average SSIM: 0.700837308599275 +[22:17:23.767] iteration 9300 [40.26 sec]: learning rate : 0.000500 loss : 0.319427 +[22:18:12.483] iteration 9400 [88.97 sec]: learning rate : 0.000500 loss : 0.463465 +[22:19:00.036] iteration 9500 [136.53 sec]: learning rate : 0.000500 loss : 0.415258 +[22:19:48.031] iteration 9600 [184.52 sec]: learning rate : 0.000500 loss : 0.445281 +[22:20:35.754] iteration 9700 [232.24 sec]: learning rate : 0.000500 loss : 0.413985 +[22:21:19.601] Epoch 16 Evaluation: +[22:22:50.835] average MSE: 0.043256650888796905 average PSNR: 26.681521633021564 average SSIM: 0.7182078569602738 +[22:22:54.886] iteration 9800 [4.03 sec]: learning rate : 0.000500 loss : 0.439550 +[22:23:43.036] iteration 9900 [52.18 sec]: learning rate : 0.000500 loss : 0.351115 +[22:24:30.612] iteration 10000 [99.75 sec]: learning rate : 0.000500 loss : 0.416228 +[22:25:18.135] iteration 10100 [147.28 sec]: learning rate : 0.000500 loss : 0.333078 +[22:26:06.089] iteration 10200 [195.23 sec]: learning 
rate : 0.000500 loss : 0.322853 +[22:26:53.658] iteration 10300 [242.80 sec]: learning rate : 0.000500 loss : 0.378802 +[22:27:25.947] Epoch 17 Evaluation: +[22:28:52.543] average MSE: 0.044530955360298254 average PSNR: 26.55939996569245 average SSIM: 0.7160158881733754 +[22:29:08.712] iteration 10400 [16.15 sec]: learning rate : 0.000500 loss : 0.697655 +[22:29:56.095] iteration 10500 [63.53 sec]: learning rate : 0.000500 loss : 0.448316 +[22:30:43.570] iteration 10600 [111.00 sec]: learning rate : 0.000500 loss : 0.384037 +[22:31:30.927] iteration 10700 [158.36 sec]: learning rate : 0.000500 loss : 0.371846 +[22:32:18.394] iteration 10800 [205.83 sec]: learning rate : 0.000500 loss : 0.425716 +[22:33:06.487] iteration 10900 [253.92 sec]: learning rate : 0.000500 loss : 0.354244 +[22:33:27.391] Epoch 18 Evaluation: +[22:35:01.239] average MSE: 0.035920660587356434 average PSNR: 27.48423696292679 average SSIM: 0.7446411713988145 +[22:35:28.148] iteration 11000 [26.89 sec]: learning rate : 0.000500 loss : 0.393125 +[22:36:15.867] iteration 11100 [74.61 sec]: learning rate : 0.000500 loss : 0.419326 +[22:37:03.524] iteration 11200 [122.26 sec]: learning rate : 0.000500 loss : 0.356785 +[22:37:51.442] iteration 11300 [170.18 sec]: learning rate : 0.000500 loss : 0.350499 +[22:38:39.098] iteration 11400 [217.84 sec]: learning rate : 0.000500 loss : 0.365562 +[22:39:26.706] iteration 11500 [265.45 sec]: learning rate : 0.000500 loss : 0.432862 +[22:39:36.222] Epoch 19 Evaluation: +[22:41:05.646] average MSE: 0.040275949299166246 average PSNR: 26.98121355126364 average SSIM: 0.7287842703305675 +[22:41:43.867] iteration 11600 [38.20 sec]: learning rate : 0.000500 loss : 20.252167 +[22:42:31.516] iteration 11700 [85.85 sec]: learning rate : 0.000500 loss : 0.526378 +[22:43:19.116] iteration 11800 [133.45 sec]: learning rate : 0.000500 loss : 0.462263 +[22:44:06.642] iteration 11900 [180.97 sec]: learning rate : 0.000500 loss : 0.390181 +[22:44:54.269] iteration 12000 [228.60 sec]: learning rate : 0.000500 loss : 0.493395 +[22:45:39.855] Epoch 20 Evaluation: +[22:47:11.949] average MSE: 0.04893774497110522 average PSNR: 26.149193991320686 average SSIM: 0.7087406382261923 +[22:47:14.066] iteration 12100 [2.11 sec]: learning rate : 0.000500 loss : 0.460946 +[22:48:02.191] iteration 12200 [50.22 sec]: learning rate : 0.000500 loss : 0.373821 +[22:48:49.745] iteration 12300 [97.77 sec]: learning rate : 0.000500 loss : 0.404700 +[22:49:37.116] iteration 12400 [145.14 sec]: learning rate : 0.000500 loss : 0.493071 +[22:50:24.698] iteration 12500 [192.73 sec]: learning rate : 0.000500 loss : 0.378403 +[22:51:12.197] iteration 12600 [240.22 sec]: learning rate : 0.000500 loss : 0.344331 +[22:51:46.784] Epoch 21 Evaluation: +[22:53:13.997] average MSE: 0.04921373020069325 average PSNR: 26.122255561871537 average SSIM: 0.7088167591420502 +[22:53:27.503] iteration 12700 [13.48 sec]: learning rate : 0.000500 loss : 0.322239 +[22:54:15.152] iteration 12800 [61.13 sec]: learning rate : 0.000500 loss : 0.397973 +[22:55:02.716] iteration 12900 [108.69 sec]: learning rate : 0.000500 loss : 0.448921 +[22:55:50.959] iteration 13000 [156.94 sec]: learning rate : 0.000500 loss : 0.504988 +[22:56:38.632] iteration 13100 [204.61 sec]: learning rate : 0.000500 loss : 0.468933 +[22:57:26.677] iteration 13200 [252.66 sec]: learning rate : 0.000500 loss : 0.623801 +[22:57:49.449] Epoch 22 Evaluation: +[22:59:15.696] average MSE: 0.04915187324604143 average PSNR: 26.125486275138744 average SSIM: 0.7083338923262188 
+[22:59:40.594] iteration 13300 [24.87 sec]: learning rate : 0.000500 loss : 0.523690 +[23:00:28.138] iteration 13400 [72.42 sec]: learning rate : 0.000500 loss : 0.539807 +[23:01:15.503] iteration 13500 [119.78 sec]: learning rate : 0.000500 loss : 0.468491 +[23:02:02.972] iteration 13600 [167.25 sec]: learning rate : 0.000500 loss : 0.486524 +[23:02:50.988] iteration 13700 [215.27 sec]: learning rate : 0.000500 loss : 0.454575 +[23:03:39.071] iteration 13800 [263.35 sec]: learning rate : 0.000500 loss : 0.396489 +[23:03:50.624] Epoch 23 Evaluation: +[23:05:22.247] average MSE: 0.049392418492757564 average PSNR: 26.102130438545267 average SSIM: 0.7065601876784217 +[23:05:58.524] iteration 13900 [36.25 sec]: learning rate : 0.000500 loss : 0.505843 +[23:06:46.123] iteration 14000 [83.85 sec]: learning rate : 0.000500 loss : 0.533906 +[23:07:33.582] iteration 14100 [131.31 sec]: learning rate : 0.000500 loss : 0.531753 +[23:08:21.573] iteration 14200 [179.30 sec]: learning rate : 0.000500 loss : 0.446960 +[23:09:09.119] iteration 14300 [226.85 sec]: learning rate : 0.000500 loss : 0.431583 +[23:09:56.547] iteration 14400 [274.27 sec]: learning rate : 0.000500 loss : 0.484896 +[23:09:56.585] Epoch 24 Evaluation: +[23:11:26.298] average MSE: 0.049332357187020344 average PSNR: 26.10664758101619 average SSIM: 0.7068137285779749 +[23:12:14.331] iteration 14500 [48.01 sec]: learning rate : 0.000500 loss : 0.445259 +[23:13:01.865] iteration 14600 [95.54 sec]: learning rate : 0.000500 loss : 0.408971 +[23:13:49.978] iteration 14700 [143.66 sec]: learning rate : 0.000500 loss : 0.472063 +[23:14:37.450] iteration 14800 [191.13 sec]: learning rate : 0.000500 loss : 0.449486 +[23:15:25.014] iteration 14900 [238.70 sec]: learning rate : 0.000500 loss : 0.522242 +[23:16:01.252] Epoch 25 Evaluation: +[23:17:30.236] average MSE: 0.049655903381536866 average PSNR: 26.077677806198118 average SSIM: 0.7059718452846242 +[23:17:41.836] iteration 15000 [11.58 sec]: learning rate : 0.000500 loss : 0.395324 +[23:18:29.734] iteration 15100 [59.47 sec]: learning rate : 0.000500 loss : 0.442873 +[23:19:17.309] iteration 15200 [107.07 sec]: learning rate : 0.000500 loss : 0.415858 +[23:20:05.512] iteration 15300 [155.25 sec]: learning rate : 0.000500 loss : 0.410453 +[23:20:53.029] iteration 15400 [202.77 sec]: learning rate : 0.000500 loss : 0.421041 +[23:21:40.424] iteration 15500 [250.16 sec]: learning rate : 0.000500 loss : 0.514647 +[23:22:05.207] Epoch 26 Evaluation: +[23:23:32.273] average MSE: 0.04936045991483521 average PSNR: 26.103655613973178 average SSIM: 0.706504059635713 +[23:23:55.235] iteration 15600 [22.94 sec]: learning rate : 0.000500 loss : 0.548277 +[23:24:42.806] iteration 15700 [70.51 sec]: learning rate : 0.000500 loss : 0.465933 +[23:25:30.999] iteration 15800 [118.70 sec]: learning rate : 0.000500 loss : 0.496460 +[23:26:18.484] iteration 15900 [166.19 sec]: learning rate : 0.000500 loss : 0.527844 +[23:27:05.916] iteration 16000 [213.62 sec]: learning rate : 0.000500 loss : 0.400343 +[23:27:53.381] iteration 16100 [261.09 sec]: learning rate : 0.000500 loss : 0.558893 +[23:28:06.739] Epoch 27 Evaluation: +[23:29:37.536] average MSE: 0.05171343889849075 average PSNR: 25.900384189199873 average SSIM: 0.6999388633338733 +[23:30:12.127] iteration 16200 [34.57 sec]: learning rate : 0.000500 loss : 0.366677 +[23:31:00.325] iteration 16300 [82.77 sec]: learning rate : 0.000500 loss : 0.408079 +[23:31:48.295] iteration 16400 [130.74 sec]: learning rate : 0.000500 loss : 0.453931 +[23:32:35.755] 
iteration 16500 [178.20 sec]: learning rate : 0.000500 loss : 0.408804 +[23:33:23.695] iteration 16600 [226.14 sec]: learning rate : 0.000500 loss : 0.370503 +[23:34:11.238] iteration 16700 [273.68 sec]: learning rate : 0.000500 loss : 0.320842 +[23:34:13.151] Epoch 28 Evaluation: +[23:35:39.791] average MSE: 0.04946291490891425 average PSNR: 26.09431746586558 average SSIM: 0.7065642480263757 +[23:36:26.373] iteration 16800 [46.56 sec]: learning rate : 0.000500 loss : 0.694064 +[23:37:13.941] iteration 16900 [94.13 sec]: learning rate : 0.000500 loss : 0.518083 +[23:38:01.430] iteration 17000 [141.61 sec]: learning rate : 0.000500 loss : 0.475802 +[23:38:48.937] iteration 17100 [189.12 sec]: learning rate : 0.000500 loss : 0.476285 +[23:39:36.548] iteration 17200 [236.73 sec]: learning rate : 0.000500 loss : 0.422576 +[23:40:14.764] Epoch 29 Evaluation: +[23:41:46.738] average MSE: 0.050428875515407125 average PSNR: 26.00959563521974 average SSIM: 0.7035235911085882 +[23:41:56.432] iteration 17300 [9.67 sec]: learning rate : 0.000500 loss : 0.336243 +[23:42:43.852] iteration 17400 [57.09 sec]: learning rate : 0.000500 loss : 0.371049 +[23:43:31.507] iteration 17500 [104.75 sec]: learning rate : 0.000500 loss : 0.381733 +[23:44:19.052] iteration 17600 [152.29 sec]: learning rate : 0.000500 loss : 0.431927 +[23:45:06.448] iteration 17700 [199.69 sec]: learning rate : 0.000500 loss : 0.572642 +[23:45:53.943] iteration 17800 [247.18 sec]: learning rate : 0.000500 loss : 0.543105 +[23:46:20.508] Epoch 30 Evaluation: +[23:47:49.030] average MSE: 0.05065595707842352 average PSNR: 25.99019451324757 average SSIM: 0.7029947981640582 +[23:48:10.379] iteration 17900 [21.33 sec]: learning rate : 0.000500 loss : 0.523042 +[23:48:57.945] iteration 18000 [68.89 sec]: learning rate : 0.000500 loss : 0.308138 +[23:49:46.014] iteration 18100 [116.96 sec]: learning rate : 0.000500 loss : 0.420423 +[23:50:33.601] iteration 18200 [164.55 sec]: learning rate : 0.000500 loss : 0.703283 +[23:51:21.346] iteration 18300 [212.29 sec]: learning rate : 0.000500 loss : 0.451160 +[23:52:08.987] iteration 18400 [259.93 sec]: learning rate : 0.000500 loss : 0.581688 +[23:52:24.229] Epoch 31 Evaluation: +[23:53:53.196] average MSE: 0.05166531401133621 average PSNR: 25.904210680038027 average SSIM: 0.7000672694234195 +[23:54:25.730] iteration 18500 [32.51 sec]: learning rate : 0.000500 loss : 0.439048 +[23:55:13.390] iteration 18600 [80.17 sec]: learning rate : 0.000500 loss : 0.415238 +[23:56:01.808] iteration 18700 [128.59 sec]: learning rate : 0.000500 loss : 0.385705 +[23:56:49.405] iteration 18800 [176.19 sec]: learning rate : 0.000500 loss : 0.400538 +[23:57:37.290] iteration 18900 [224.07 sec]: learning rate : 0.000500 loss : 0.585303 +[23:58:26.169] iteration 19000 [272.95 sec]: learning rate : 0.000500 loss : 0.394843 +[23:58:29.985] Epoch 32 Evaluation: +[00:00:01.992] average MSE: 0.04980401202576362 average PSNR: 26.063964691047087 average SSIM: 0.706422141557223 +[00:00:45.834] iteration 19100 [43.82 sec]: learning rate : 0.000500 loss : 0.358874 +[00:01:33.364] iteration 19200 [91.35 sec]: learning rate : 0.000500 loss : 0.522801 +[00:02:20.745] iteration 19300 [138.73 sec]: learning rate : 0.000500 loss : 0.458177 +[00:03:08.647] iteration 19400 [186.63 sec]: learning rate : 0.000500 loss : 0.397748 +[00:03:56.164] iteration 19500 [234.15 sec]: learning rate : 0.000500 loss : 0.394857 +[00:04:36.540] Epoch 33 Evaluation: +[00:06:03.264] average MSE: 0.05214434070960973 average PSNR: 25.863415058367906 average 
SSIM: 0.699652888924183 +[00:06:11.109] iteration 19600 [7.82 sec]: learning rate : 0.000500 loss : 0.437926 +[00:06:58.889] iteration 19700 [55.60 sec]: learning rate : 0.000500 loss : 4.331016 +[00:07:46.824] iteration 19800 [103.54 sec]: learning rate : 0.000500 loss : 81.103569 +[00:08:34.218] iteration 19900 [150.93 sec]: learning rate : 0.000500 loss : 14.547447 +[00:09:21.723] iteration 20000 [198.44 sec]: learning rate : 0.000125 loss : 2.398916 +[00:09:21.881] save model to model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/iter_20000.pth +[00:10:09.961] iteration 20100 [246.68 sec]: learning rate : 0.000250 loss : 1.617102 +[00:10:39.066] Epoch 34 Evaluation: +[00:12:08.097] average MSE: 0.047605437799345146 average PSNR: 26.260991013316197 average SSIM: 0.7136253446956282 +[00:12:27.253] iteration 20200 [19.13 sec]: learning rate : 0.000250 loss : 0.518270 +[00:13:14.828] iteration 20300 [66.71 sec]: learning rate : 0.000250 loss : 0.593489 +[00:14:02.305] iteration 20400 [114.19 sec]: learning rate : 0.000250 loss : 0.769219 +[00:14:49.672] iteration 20500 [161.55 sec]: learning rate : 0.000250 loss : 0.651041 +[00:15:37.755] iteration 20600 [209.63 sec]: learning rate : 0.000250 loss : 0.775320 +[00:16:25.743] iteration 20700 [257.62 sec]: learning rate : 0.000250 loss : 0.554874 +[00:16:42.977] Epoch 35 Evaluation: +[00:18:14.054] average MSE: 0.04664626523563848 average PSNR: 26.348165431934056 average SSIM: 0.7174615310887414 +[00:18:44.599] iteration 20800 [30.52 sec]: learning rate : 0.000250 loss : 0.548372 +[00:19:32.223] iteration 20900 [78.15 sec]: learning rate : 0.000250 loss : 0.580007 +[00:20:19.695] iteration 21000 [125.62 sec]: learning rate : 0.000250 loss : 0.559660 +[00:21:07.539] iteration 21100 [173.46 sec]: learning rate : 0.000250 loss : 0.587914 +[00:21:55.054] iteration 21200 [220.98 sec]: learning rate : 0.000250 loss : 0.434778 +[00:22:42.594] iteration 21300 [268.52 sec]: learning rate : 0.000250 loss : 0.443477 +[00:22:48.316] Epoch 36 Evaluation: +[00:24:18.887] average MSE: 0.04815151014202667 average PSNR: 26.210608222555276 average SSIM: 0.7097672898884114 +[00:25:01.369] iteration 21400 [42.47 sec]: learning rate : 0.000250 loss : 0.660786 +[00:25:49.390] iteration 21500 [90.48 sec]: learning rate : 0.000250 loss : 0.370799 +[00:26:37.181] iteration 21600 [138.27 sec]: learning rate : 0.000250 loss : 0.487237 +[00:27:24.663] iteration 21700 [185.75 sec]: learning rate : 0.000250 loss : 0.565249 +[00:28:12.092] iteration 21800 [233.18 sec]: learning rate : 0.000250 loss : 0.358159 +[00:28:54.011] Epoch 37 Evaluation: +[00:30:21.398] average MSE: 0.04777468871645227 average PSNR: 26.244741986968865 average SSIM: 0.7133030050221283 +[00:30:27.360] iteration 21900 [5.94 sec]: learning rate : 0.000250 loss : 0.479141 +[00:31:14.872] iteration 22000 [53.45 sec]: learning rate : 0.000250 loss : 0.363191 +[00:32:02.552] iteration 22100 [101.13 sec]: learning rate : 0.000250 loss : 0.375342 +[00:32:50.248] iteration 22200 [148.83 sec]: learning rate : 0.000250 loss : 0.438995 +[00:33:37.758] iteration 22300 [196.34 sec]: learning rate : 0.000250 loss : 0.454713 +[00:34:25.649] iteration 22400 [244.23 sec]: learning rate : 0.000250 loss : 0.437832 +[00:34:56.194] Epoch 38 Evaluation: +[00:36:22.660] average MSE: 0.04791543969571445 average PSNR: 26.232290858913046 average SSIM: 0.7132477476515511 +[00:36:39.927] iteration 22500 [17.24 sec]: learning rate : 0.000250 loss : 0.515409 +[00:37:27.596] iteration 22600 [64.91 sec]: learning rate : 0.000250 loss : 
0.483752 +[00:38:15.597] iteration 22700 [112.91 sec]: learning rate : 0.000250 loss : 1.483458 +[00:39:03.084] iteration 22800 [160.40 sec]: learning rate : 0.000250 loss : 0.442352 +[00:39:50.578] iteration 22900 [207.89 sec]: learning rate : 0.000250 loss : 0.488372 +[00:40:38.333] iteration 23000 [255.65 sec]: learning rate : 0.000250 loss : 0.502172 +[00:40:57.313] Epoch 39 Evaluation: +[00:42:23.840] average MSE: 0.04772741111759646 average PSNR: 26.248962165894262 average SSIM: 0.7131939153625835 +[00:42:52.600] iteration 23100 [28.74 sec]: learning rate : 0.000250 loss : 0.362155 +[00:43:40.690] iteration 23200 [76.83 sec]: learning rate : 0.000250 loss : 0.582984 +[00:44:28.354] iteration 23300 [124.49 sec]: learning rate : 0.000250 loss : 0.439951 +[00:45:16.015] iteration 23400 [172.15 sec]: learning rate : 0.000250 loss : 0.474683 +[00:46:03.587] iteration 23500 [219.72 sec]: learning rate : 0.000250 loss : 0.548576 +[00:46:51.206] iteration 23600 [267.34 sec]: learning rate : 0.000250 loss : 0.376893 +[00:46:58.825] Epoch 40 Evaluation: +[00:48:30.535] average MSE: 0.04731809911837422 average PSNR: 26.28590608566495 average SSIM: 0.7151082783233853 +[00:49:11.256] iteration 23700 [40.70 sec]: learning rate : 0.000250 loss : 0.488056 +[00:49:58.804] iteration 23800 [88.25 sec]: learning rate : 0.000250 loss : 0.560768 +[00:50:46.504] iteration 23900 [135.95 sec]: learning rate : 0.000250 loss : 0.380096 +[00:51:34.093] iteration 24000 [183.54 sec]: learning rate : 0.000250 loss : 0.437086 +[00:52:22.026] iteration 24100 [231.47 sec]: learning rate : 0.000250 loss : 0.403139 +[00:53:05.842] Epoch 41 Evaluation: +[00:54:36.326] average MSE: 0.0473525057813791 average PSNR: 26.282568982107463 average SSIM: 0.7154695642433281 +[00:54:40.325] iteration 24200 [3.98 sec]: learning rate : 0.000250 loss : 0.448072 +[00:55:28.328] iteration 24300 [51.98 sec]: learning rate : 0.000250 loss : 0.461091 +[00:56:15.730] iteration 24400 [99.38 sec]: learning rate : 0.000250 loss : 0.407469 +[00:57:03.215] iteration 24500 [146.87 sec]: learning rate : 0.000250 loss : 0.400850 +[00:57:50.588] iteration 24600 [194.24 sec]: learning rate : 0.000250 loss : 0.506545 +[00:58:38.115] iteration 24700 [241.77 sec]: learning rate : 0.000250 loss : 0.510071 +[00:59:10.349] Epoch 42 Evaluation: +[01:00:37.451] average MSE: 0.04723425914996036 average PSNR: 26.29373389704805 average SSIM: 0.7157659394398977 +[01:00:53.395] iteration 24800 [15.92 sec]: learning rate : 0.000250 loss : 0.484940 +[01:01:41.016] iteration 24900 [63.54 sec]: learning rate : 0.000250 loss : 0.513320 +[01:02:28.625] iteration 25000 [111.15 sec]: learning rate : 0.000250 loss : 0.478274 +[01:03:16.507] iteration 25100 [159.03 sec]: learning rate : 0.000250 loss : 0.537598 +[01:04:03.891] iteration 25200 [206.42 sec]: learning rate : 0.000250 loss : 0.557700 +[01:04:51.408] iteration 25300 [253.93 sec]: learning rate : 0.000250 loss : 0.469128 +[01:05:12.296] Epoch 43 Evaluation: +[01:06:39.915] average MSE: 0.04787634629671309 average PSNR: 26.234164435316412 average SSIM: 0.7138990593814742 +[01:07:06.822] iteration 25400 [26.88 sec]: learning rate : 0.000250 loss : 0.517124 +[01:07:54.236] iteration 25500 [74.30 sec]: learning rate : 0.000250 loss : 0.495025 +[01:08:41.704] iteration 25600 [121.76 sec]: learning rate : 0.000250 loss : 8.165817 +[01:09:29.169] iteration 25700 [169.25 sec]: learning rate : 0.000250 loss : 0.544526 +[01:10:17.476] iteration 25800 [217.54 sec]: learning rate : 0.000250 loss : 3.282400 +[01:11:04.965] 
iteration 25900 [265.03 sec]: learning rate : 0.000250 loss : 0.511600 +[01:11:14.688] Epoch 44 Evaluation: +[01:12:40.604] average MSE: 0.047418268036732006 average PSNR: 26.275206211658436 average SSIM: 0.7129969544396437 +[01:13:18.800] iteration 26000 [38.17 sec]: learning rate : 0.000250 loss : 0.360469 +[01:14:06.464] iteration 26100 [85.84 sec]: learning rate : 0.000250 loss : 0.425182 +[01:14:53.932] iteration 26200 [133.30 sec]: learning rate : 0.000250 loss : 0.426583 +[01:15:41.318] iteration 26300 [180.69 sec]: learning rate : 0.000250 loss : 0.462646 +[01:16:28.813] iteration 26400 [228.19 sec]: learning rate : 0.000250 loss : 4.231050 +[01:17:15.110] Epoch 45 Evaluation: +[01:18:45.397] average MSE: 0.04734854439659598 average PSNR: 26.281727824290982 average SSIM: 0.7142003282246115 +[01:18:47.505] iteration 26500 [2.09 sec]: learning rate : 0.000250 loss : 0.524553 +[01:19:34.957] iteration 26600 [49.54 sec]: learning rate : 0.000250 loss : 0.494790 +[01:20:22.470] iteration 26700 [97.05 sec]: learning rate : 0.000250 loss : 0.509060 +[01:21:09.942] iteration 26800 [144.52 sec]: learning rate : 0.000250 loss : 0.617178 +[01:21:57.324] iteration 26900 [191.90 sec]: learning rate : 0.000250 loss : 0.440680 +[01:22:45.363] iteration 27000 [239.94 sec]: learning rate : 0.000250 loss : 6.743861 +[01:23:19.593] Epoch 46 Evaluation: +[01:24:48.583] average MSE: 0.047772725392983856 average PSNR: 26.24600257115558 average SSIM: 0.7134961488112075 +[01:25:02.099] iteration 27100 [13.49 sec]: learning rate : 0.000250 loss : 0.365532 +[01:25:50.315] iteration 27200 [61.71 sec]: learning rate : 0.000250 loss : 0.471332 +[01:26:37.902] iteration 27300 [109.30 sec]: learning rate : 0.000250 loss : 0.392189 +[01:27:25.389] iteration 27400 [156.78 sec]: learning rate : 0.000250 loss : 0.559128 +[01:28:13.780] iteration 27500 [205.17 sec]: learning rate : 0.000250 loss : 0.500776 +[01:29:01.327] iteration 27600 [252.72 sec]: learning rate : 0.000250 loss : 0.457149 +[01:29:24.088] Epoch 47 Evaluation: +[01:30:50.801] average MSE: 0.04752282363232417 average PSNR: 26.26464810940733 average SSIM: 0.7121966082315254 +[01:31:15.675] iteration 27700 [24.85 sec]: learning rate : 0.000250 loss : 0.397298 +[01:32:03.316] iteration 27800 [72.49 sec]: learning rate : 0.000250 loss : 0.504179 +[01:32:50.814] iteration 27900 [119.99 sec]: learning rate : 0.000250 loss : 0.650060 +[01:33:39.008] iteration 28000 [168.18 sec]: learning rate : 0.000250 loss : 0.486569 +[01:34:26.633] iteration 28100 [215.81 sec]: learning rate : 0.000250 loss : 0.503464 +[01:35:14.204] iteration 28200 [263.38 sec]: learning rate : 0.000250 loss : 0.448663 +[01:35:25.764] Epoch 48 Evaluation: +[01:36:57.906] average MSE: 0.04839244934506186 average PSNR: 26.185816050710418 average SSIM: 0.7086252502193047 +[01:37:34.127] iteration 28300 [36.20 sec]: learning rate : 0.000250 loss : 0.348510 +[01:38:21.646] iteration 28400 [83.72 sec]: learning rate : 0.000250 loss : 0.408280 +[01:39:09.467] iteration 28500 [131.54 sec]: learning rate : 0.000250 loss : 0.420664 +[01:39:56.923] iteration 28600 [178.99 sec]: learning rate : 0.000250 loss : 0.398417 +[01:40:44.990] iteration 28700 [227.06 sec]: learning rate : 0.000250 loss : 0.514631 +[01:41:32.395] iteration 28800 [274.47 sec]: learning rate : 0.000250 loss : 0.422852 +[01:41:32.432] Epoch 49 Evaluation: +[01:42:59.287] average MSE: 0.04825374192466146 average PSNR: 26.197215400671656 average SSIM: 0.7090993999760924 +[01:43:47.151] iteration 28900 [47.84 sec]: learning rate : 
0.000250 loss : 0.346344 +[01:44:34.694] iteration 29000 [95.39 sec]: learning rate : 0.000250 loss : 0.480683 +[01:45:22.593] iteration 29100 [143.28 sec]: learning rate : 0.000250 loss : 0.613521 +[01:46:10.614] iteration 29200 [191.31 sec]: learning rate : 0.000250 loss : 0.415421 +[01:46:58.143] iteration 29300 [238.83 sec]: learning rate : 0.000250 loss : 0.446860 +[01:47:34.297] Epoch 50 Evaluation: +[01:49:06.192] average MSE: 0.04884435463605748 average PSNR: 26.142816897343902 average SSIM: 0.707473994817096 +[01:49:17.769] iteration 29400 [11.55 sec]: learning rate : 0.000250 loss : 0.474821 +[01:50:05.331] iteration 29500 [59.11 sec]: learning rate : 0.000250 loss : 0.432274 +[01:50:53.235] iteration 29600 [107.02 sec]: learning rate : 0.000250 loss : 0.410026 +[01:51:40.693] iteration 29700 [154.48 sec]: learning rate : 0.000250 loss : 0.448890 +[01:52:28.179] iteration 29800 [201.96 sec]: learning rate : 0.000250 loss : 0.426962 +[01:53:15.564] iteration 29900 [249.35 sec]: learning rate : 0.000250 loss : 0.444194 +[01:53:40.319] Epoch 51 Evaluation: +[01:55:07.747] average MSE: 0.050085619415389436 average PSNR: 26.031550774970988 average SSIM: 0.7022258479075726 +[01:55:31.308] iteration 30000 [23.54 sec]: learning rate : 0.000250 loss : 0.508263 +[01:56:19.435] iteration 30100 [71.67 sec]: learning rate : 0.000250 loss : 0.584759 +[01:57:06.804] iteration 30200 [119.03 sec]: learning rate : 0.000250 loss : 31.227755 +[01:57:54.310] iteration 30300 [166.54 sec]: learning rate : 0.000250 loss : 0.516747 +[01:58:41.839] iteration 30400 [214.07 sec]: learning rate : 0.000250 loss : 0.417698 +[01:59:29.269] iteration 30500 [261.52 sec]: learning rate : 0.000250 loss : 0.440952 +[01:59:42.602] Epoch 52 Evaluation: +[02:01:09.303] average MSE: 0.04882368983518989 average PSNR: 26.14395602795087 average SSIM: 0.705819783324655 +[02:01:44.373] iteration 30600 [35.05 sec]: learning rate : 0.000250 loss : 0.681770 +[02:02:31.745] iteration 30700 [82.42 sec]: learning rate : 0.000250 loss : 0.504426 +[02:03:20.114] iteration 30800 [130.79 sec]: learning rate : 0.000250 loss : 0.432510 +[02:04:07.574] iteration 30900 [178.25 sec]: learning rate : 0.000250 loss : 0.456404 +[02:04:54.996] iteration 31000 [225.67 sec]: learning rate : 0.000250 loss : 0.480301 +[02:05:42.517] iteration 31100 [273.19 sec]: learning rate : 0.000250 loss : 0.405897 +[02:05:44.432] Epoch 53 Evaluation: +[02:07:11.406] average MSE: 0.04834143720517448 average PSNR: 26.18733164703609 average SSIM: 0.7099957968235537 +[02:07:57.392] iteration 31200 [45.96 sec]: learning rate : 0.000250 loss : 0.530602 +[02:08:44.745] iteration 31300 [93.32 sec]: learning rate : 0.000250 loss : 0.456990 +[02:09:32.207] iteration 31400 [140.78 sec]: learning rate : 0.000250 loss : 0.442935 +[02:10:19.698] iteration 31500 [188.27 sec]: learning rate : 0.000250 loss : 0.491820 +[02:11:07.682] iteration 31600 [236.25 sec]: learning rate : 0.000250 loss : 0.415733 +[02:11:45.927] Epoch 54 Evaluation: +[02:13:13.727] average MSE: 0.048470345648826074 average PSNR: 26.175049048272413 average SSIM: 0.7090659346824731 +[02:13:23.425] iteration 31700 [9.67 sec]: learning rate : 0.000250 loss : 0.418334 +[02:14:11.065] iteration 31800 [57.31 sec]: learning rate : 0.000250 loss : 0.459983 +[02:14:58.535] iteration 31900 [104.78 sec]: learning rate : 0.000250 loss : 0.464088 +[02:15:46.071] iteration 32000 [152.32 sec]: learning rate : 0.000250 loss : 0.504537 +[02:16:33.545] iteration 32100 [199.79 sec]: learning rate : 0.000250 loss : 
0.526369 +[02:17:21.084] iteration 32200 [247.33 sec]: learning rate : 0.000250 loss : 0.446971 +[02:17:47.660] Epoch 55 Evaluation: +[02:19:17.988] average MSE: 0.04841296162046309 average PSNR: 26.18065787995967 average SSIM: 0.7106427102561423 +[02:19:39.192] iteration 32300 [21.18 sec]: learning rate : 0.000250 loss : 0.540159 +[02:20:26.584] iteration 32400 [68.57 sec]: learning rate : 0.000250 loss : 0.405148 +[02:21:14.424] iteration 32500 [116.41 sec]: learning rate : 0.000250 loss : 0.477586 +[02:22:01.950] iteration 32600 [163.94 sec]: learning rate : 0.000250 loss : 0.551084 +[02:22:49.401] iteration 32700 [211.39 sec]: learning rate : 0.000250 loss : 0.486596 +[02:23:36.919] iteration 32800 [258.91 sec]: learning rate : 0.000250 loss : 0.443410 +[02:23:52.622] Epoch 56 Evaluation: +[02:25:19.584] average MSE: 0.04897330001452219 average PSNR: 26.13035923333411 average SSIM: 0.7075064680843705 +[02:25:52.658] iteration 32900 [33.05 sec]: learning rate : 0.000250 loss : 0.461296 +[02:26:40.089] iteration 33000 [80.48 sec]: learning rate : 0.000250 loss : 0.295342 +[02:27:27.541] iteration 33100 [127.93 sec]: learning rate : 0.000250 loss : 0.455983 +[02:28:15.016] iteration 33200 [175.41 sec]: learning rate : 0.000250 loss : 0.553469 +[02:29:02.453] iteration 33300 [222.85 sec]: learning rate : 0.000250 loss : 0.509931 +[02:29:50.875] iteration 33400 [271.29 sec]: learning rate : 0.000250 loss : 0.506702 +[02:29:54.878] Epoch 57 Evaluation: +[02:31:21.954] average MSE: 0.04940862080884357 average PSNR: 26.09145824102127 average SSIM: 0.7080898715721351 +[02:32:05.865] iteration 33500 [43.89 sec]: learning rate : 0.000250 loss : 0.337266 +[02:32:53.436] iteration 33600 [91.46 sec]: learning rate : 0.000250 loss : 0.422372 +[02:33:41.526] iteration 33700 [139.55 sec]: learning rate : 0.000250 loss : 0.417385 +[02:34:29.021] iteration 33800 [187.04 sec]: learning rate : 0.000250 loss : 0.425321 +[02:35:16.869] iteration 33900 [234.89 sec]: learning rate : 0.000250 loss : 0.361534 +[02:35:56.733] Epoch 58 Evaluation: +[02:37:22.558] average MSE: 0.05003608392971598 average PSNR: 26.0359288267868 average SSIM: 0.7061250150559997 +[02:37:30.508] iteration 34000 [7.93 sec]: learning rate : 0.000250 loss : 0.452361 +[02:38:17.901] iteration 34100 [55.32 sec]: learning rate : 0.000250 loss : 0.446801 +[02:39:05.630] iteration 34200 [103.05 sec]: learning rate : 0.000250 loss : 0.671751 +[02:39:53.123] iteration 34300 [150.54 sec]: learning rate : 0.000250 loss : 0.444180 +[02:40:41.385] iteration 34400 [198.80 sec]: learning rate : 0.000250 loss : 2.102408 +[02:41:28.903] iteration 34500 [246.32 sec]: learning rate : 0.000250 loss : 0.605284 +[02:41:57.435] Epoch 59 Evaluation: +[02:43:23.867] average MSE: 0.05033776382184068 average PSNR: 26.009501155497496 average SSIM: 0.7042945304916827 +[02:43:43.127] iteration 34600 [19.24 sec]: learning rate : 0.000250 loss : 0.426081 +[02:44:30.759] iteration 34700 [66.87 sec]: learning rate : 0.000250 loss : 0.475611 +[02:45:18.222] iteration 34800 [114.33 sec]: learning rate : 0.000250 loss : 0.508738 +[02:46:05.579] iteration 34900 [161.71 sec]: learning rate : 0.000250 loss : 0.370138 +[02:46:53.430] iteration 35000 [209.54 sec]: learning rate : 0.000250 loss : 0.434316 +[02:47:41.021] iteration 35100 [257.13 sec]: learning rate : 0.000250 loss : 0.433193 +[02:47:58.103] Epoch 60 Evaluation: +[02:49:24.224] average MSE: 0.04939462389005895 average PSNR: 26.09459040708071 average SSIM: 0.7112206418230942 +[02:49:54.863] iteration 35200 [30.62 
sec]: learning rate : 0.000250 loss : 0.447790 +[02:50:42.528] iteration 35300 [78.28 sec]: learning rate : 0.000250 loss : 0.465624 +[02:51:30.157] iteration 35400 [125.91 sec]: learning rate : 0.000250 loss : 0.405520 +[02:52:18.013] iteration 35500 [173.77 sec]: learning rate : 0.000250 loss : 2.078802 +[02:53:05.532] iteration 35600 [221.28 sec]: learning rate : 0.000250 loss : 0.398493 +[02:53:52.946] iteration 35700 [268.70 sec]: learning rate : 0.000250 loss : 0.330938 +[02:53:58.741] Epoch 61 Evaluation: +[02:55:25.228] average MSE: 0.04913338185338563 average PSNR: 26.116723338560263 average SSIM: 0.712047166757096 +[02:56:07.577] iteration 35800 [42.33 sec]: learning rate : 0.000250 loss : 0.559209 +[02:56:55.608] iteration 35900 [90.36 sec]: learning rate : 0.000250 loss : 0.379301 +[02:57:43.409] iteration 36000 [138.16 sec]: learning rate : 0.000250 loss : 0.461107 +[02:58:30.885] iteration 36100 [185.63 sec]: learning rate : 0.000250 loss : 0.308977 +[02:59:18.559] iteration 36200 [233.31 sec]: learning rate : 0.000250 loss : 0.491302 +[03:00:00.341] Epoch 62 Evaluation: +[03:01:26.719] average MSE: 0.04921022910821607 average PSNR: 26.11018282980537 average SSIM: 0.7121284070516213 +[03:01:32.636] iteration 36300 [5.89 sec]: learning rate : 0.000250 loss : 0.349248 +[03:02:20.261] iteration 36400 [53.52 sec]: learning rate : 0.000250 loss : 0.374848 +[03:03:08.825] iteration 36500 [102.08 sec]: learning rate : 0.000250 loss : 0.368675 +[03:03:56.329] iteration 36600 [149.59 sec]: learning rate : 0.000250 loss : 0.435133 +[03:04:43.945] iteration 36700 [197.20 sec]: learning rate : 0.000250 loss : 0.469404 +[03:05:31.684] iteration 36800 [244.94 sec]: learning rate : 0.000250 loss : 0.458239 +[03:06:02.133] Epoch 63 Evaluation: +[03:07:29.365] average MSE: 0.04946440047872141 average PSNR: 26.087093186954043 average SSIM: 0.7122166062813089 +[03:07:46.652] iteration 36900 [17.26 sec]: learning rate : 0.000250 loss : 0.627547 +[03:08:34.809] iteration 37000 [65.42 sec]: learning rate : 0.000250 loss : 0.511343 +[03:09:22.182] iteration 37100 [112.79 sec]: learning rate : 0.000250 loss : 0.487277 +[03:10:09.696] iteration 37200 [160.31 sec]: learning rate : 0.000250 loss : 0.522611 +[03:10:57.817] iteration 37300 [208.43 sec]: learning rate : 0.000250 loss : 0.593651 +[03:11:45.224] iteration 37400 [255.84 sec]: learning rate : 0.000250 loss : 0.400859 +[03:12:04.215] Epoch 64 Evaluation: +[03:13:31.357] average MSE: 0.049853177331260254 average PSNR: 26.05289422036699 average SSIM: 0.7084155540811947 +[03:14:00.157] iteration 37500 [28.78 sec]: learning rate : 0.000250 loss : 0.609452 +[03:14:48.639] iteration 37600 [77.26 sec]: learning rate : 0.000250 loss : 0.471964 +[03:15:35.992] iteration 37700 [124.61 sec]: learning rate : 0.000250 loss : 0.534148 +[03:16:23.488] iteration 37800 [172.11 sec]: learning rate : 0.000250 loss : 0.500419 +[03:17:10.906] iteration 37900 [219.53 sec]: learning rate : 0.000250 loss : 0.464976 +[03:17:58.401] iteration 38000 [267.02 sec]: learning rate : 0.000250 loss : 0.427574 +[03:18:06.002] Epoch 65 Evaluation: +[03:19:33.275] average MSE: 0.0496268207705901 average PSNR: 26.074300173073816 average SSIM: 0.7097539345161622 +[03:20:14.028] iteration 38100 [40.73 sec]: learning rate : 0.000250 loss : 0.360187 +[03:21:01.438] iteration 38200 [88.14 sec]: learning rate : 0.000250 loss : 0.491831 +[03:21:48.964] iteration 38300 [135.67 sec]: learning rate : 0.000250 loss : 0.492821 +[03:22:36.492] iteration 38400 [183.19 sec]: learning rate : 
0.000250 loss : 0.460588 +[03:23:24.287] iteration 38500 [230.99 sec]: learning rate : 0.000250 loss : 0.520835 +[03:24:08.047] Epoch 66 Evaluation: +[03:25:37.854] average MSE: 0.04930021783107513 average PSNR: 26.102616403397906 average SSIM: 0.7096246602578911 +[03:25:42.353] iteration 38600 [4.48 sec]: learning rate : 0.000250 loss : 0.536835 +[03:26:30.079] iteration 38700 [52.20 sec]: learning rate : 0.000250 loss : 0.431708 +[03:27:17.573] iteration 38800 [99.70 sec]: learning rate : 0.000250 loss : 0.398314 +[03:28:05.232] iteration 38900 [147.36 sec]: learning rate : 0.000250 loss : 0.360378 +[03:28:52.765] iteration 39000 [194.89 sec]: learning rate : 0.000250 loss : 0.328919 +[03:29:40.404] iteration 39100 [242.53 sec]: learning rate : 0.000250 loss : 0.402142 +[03:30:12.807] Epoch 67 Evaluation: +[03:31:42.497] average MSE: 0.0512965844405298 average PSNR: 25.929576936891078 average SSIM: 0.7022896807705895 +[03:31:57.874] iteration 39200 [15.35 sec]: learning rate : 0.000250 loss : 0.396311 +[03:32:45.719] iteration 39300 [63.20 sec]: learning rate : 0.000250 loss : 0.456427 +[03:33:33.692] iteration 39400 [111.17 sec]: learning rate : 0.000250 loss : 0.592444 +[03:34:21.221] iteration 39500 [158.70 sec]: learning rate : 0.000250 loss : 0.443138 +[03:35:08.679] iteration 39600 [206.16 sec]: learning rate : 0.000250 loss : 0.529652 +[03:35:56.205] iteration 39700 [253.69 sec]: learning rate : 0.000250 loss : 0.328401 +[03:36:17.147] Epoch 68 Evaluation: +[03:37:47.118] average MSE: 0.05216179099137225 average PSNR: 25.857166828645838 average SSIM: 0.6986169182990727 +[03:38:14.130] iteration 39800 [26.99 sec]: learning rate : 0.000250 loss : 0.363996 +[03:39:01.718] iteration 39900 [74.58 sec]: learning rate : 0.000250 loss : 0.488752 +[03:39:49.366] iteration 40000 [122.22 sec]: learning rate : 0.000063 loss : 0.525329 +[03:39:49.525] save model to model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/iter_40000.pth +[03:40:37.209] iteration 40100 [170.09 sec]: learning rate : 0.000125 loss : 0.404352 +[03:41:25.196] iteration 40200 [218.05 sec]: learning rate : 0.000125 loss : 0.445576 +[03:42:13.159] iteration 40300 [266.02 sec]: learning rate : 0.000125 loss : 0.479373 +[03:42:22.652] Epoch 69 Evaluation: +[03:43:48.917] average MSE: 0.04971521083575557 average PSNR: 26.066283341541155 average SSIM: 0.7085844013301866 +[03:44:27.204] iteration 40400 [38.26 sec]: learning rate : 0.000125 loss : 0.455197 +[03:45:14.833] iteration 40500 [85.89 sec]: learning rate : 0.000125 loss : 0.392522 +[03:46:02.325] iteration 40600 [133.39 sec]: learning rate : 0.000125 loss : 0.414395 +[03:46:49.778] iteration 40700 [180.84 sec]: learning rate : 0.000125 loss : 0.427129 +[03:47:37.272] iteration 40800 [228.33 sec]: learning rate : 0.000125 loss : 0.409371 +[03:48:23.714] Epoch 70 Evaluation: +[03:49:50.081] average MSE: 0.0506631621480412 average PSNR: 25.984032217953736 average SSIM: 0.7043218875116178 +[03:49:52.191] iteration 40900 [2.09 sec]: learning rate : 0.000125 loss : 0.486466 +[03:50:40.273] iteration 41000 [50.17 sec]: learning rate : 0.000125 loss : 0.499510 +[03:51:27.902] iteration 41100 [97.80 sec]: learning rate : 0.000125 loss : 0.425045 +[03:52:15.408] iteration 41200 [145.30 sec]: learning rate : 0.000125 loss : 0.468339 +[03:53:03.140] iteration 41300 [193.04 sec]: learning rate : 0.000125 loss : 0.322854 +[03:53:51.254] iteration 41400 [241.15 sec]: learning rate : 0.000125 loss : 0.426397 +[03:54:25.518] Epoch 71 Evaluation: +[03:55:56.565] average MSE: 
0.05075612572079915 average PSNR: 25.975787748954577 average SSIM: 0.7045221443990438 +[03:56:10.067] iteration 41500 [13.48 sec]: learning rate : 0.000125 loss : 0.330029 +[03:56:57.683] iteration 41600 [61.09 sec]: learning rate : 0.000125 loss : 0.418698 +[03:57:45.165] iteration 41700 [108.58 sec]: learning rate : 0.000125 loss : 0.755813 +[03:58:32.560] iteration 41800 [155.97 sec]: learning rate : 0.000125 loss : 3.410006 +[03:59:21.042] iteration 41900 [204.45 sec]: learning rate : 0.000125 loss : 0.389112 +[04:00:08.545] iteration 42000 [251.96 sec]: learning rate : 0.000125 loss : 0.417183 +[04:00:31.486] Epoch 72 Evaluation: +[04:01:57.543] average MSE: 0.052005801564731356 average PSNR: 25.87096980837016 average SSIM: 0.7004122636930717 +[04:02:22.462] iteration 42100 [24.90 sec]: learning rate : 0.000125 loss : 0.273352 +[04:03:10.226] iteration 42200 [72.66 sec]: learning rate : 0.000125 loss : 0.405133 +[04:03:58.164] iteration 42300 [120.60 sec]: learning rate : 0.000125 loss : 0.310748 +[04:04:46.580] iteration 42400 [169.02 sec]: learning rate : 0.000125 loss : 0.385165 +[04:05:34.129] iteration 42500 [216.56 sec]: learning rate : 0.000125 loss : 0.515553 +[04:06:21.536] iteration 42600 [263.97 sec]: learning rate : 0.000125 loss : 0.342173 +[04:06:33.057] Epoch 73 Evaluation: +[04:08:01.506] average MSE: 0.05099872319908819 average PSNR: 25.955862012427268 average SSIM: 0.7038189673176911 +[04:08:37.742] iteration 42700 [36.21 sec]: learning rate : 0.000125 loss : 0.351253 +[04:09:25.304] iteration 42800 [83.78 sec]: learning rate : 0.000125 loss : 0.347036 +[04:10:13.393] iteration 42900 [131.87 sec]: learning rate : 0.000125 loss : 2.616907 +[04:11:01.427] iteration 43000 [179.90 sec]: learning rate : 0.000125 loss : 0.642521 +[04:11:49.002] iteration 43100 [227.47 sec]: learning rate : 0.000125 loss : 0.331886 +[04:12:36.485] iteration 43200 [274.96 sec]: learning rate : 0.000125 loss : 0.467551 +[04:12:36.523] Epoch 74 Evaluation: +[04:14:05.260] average MSE: 0.05064485806831851 average PSNR: 25.984845568493437 average SSIM: 0.7060252432306117 +[04:14:53.051] iteration 43300 [47.77 sec]: learning rate : 0.000125 loss : 0.396817 +[04:15:40.941] iteration 43400 [95.66 sec]: learning rate : 0.000125 loss : 0.523029 +[04:16:28.560] iteration 43500 [143.28 sec]: learning rate : 0.000125 loss : 0.425380 +[04:17:16.673] iteration 43600 [191.39 sec]: learning rate : 0.000125 loss : 0.374781 +[04:18:04.068] iteration 43700 [238.78 sec]: learning rate : 0.000125 loss : 0.626706 +[04:18:40.499] Epoch 75 Evaluation: +[04:20:10.769] average MSE: 0.05156987371931179 average PSNR: 25.907744500589924 average SSIM: 0.6992234359789052 +[04:20:22.355] iteration 43800 [11.56 sec]: learning rate : 0.000125 loss : 0.417601 +[04:21:10.660] iteration 43900 [59.87 sec]: learning rate : 0.000125 loss : 0.343954 +[04:21:58.227] iteration 44000 [107.44 sec]: learning rate : 0.000125 loss : 0.391617 +[04:22:45.916] iteration 44100 [155.13 sec]: learning rate : 0.000125 loss : 7.514569 +[04:23:33.591] iteration 44200 [202.80 sec]: learning rate : 0.000125 loss : 0.527096 +[04:24:21.169] iteration 44300 [250.38 sec]: learning rate : 0.000125 loss : 0.391955 +[04:24:46.035] Epoch 76 Evaluation: +[04:26:18.167] average MSE: 0.05065669610063037 average PSNR: 25.984331406339322 average SSIM: 0.7054240139447392 +[04:26:41.254] iteration 44400 [23.06 sec]: learning rate : 0.000125 loss : 0.349485 +[04:27:29.255] iteration 44500 [71.06 sec]: learning rate : 0.000125 loss : 0.333133 +[04:28:16.694] 
iteration 44600 [118.50 sec]: learning rate : 0.000125 loss : 0.392055 +[04:29:04.251] iteration 44700 [166.06 sec]: learning rate : 0.000125 loss : 0.315268 +[04:29:51.841] iteration 44800 [213.65 sec]: learning rate : 0.000125 loss : 0.465278 +[04:30:39.428] iteration 44900 [261.24 sec]: learning rate : 0.000125 loss : 0.312571 +[04:30:52.720] Epoch 77 Evaluation: +[04:32:20.446] average MSE: 0.051124642418545925 average PSNR: 25.944116421680793 average SSIM: 0.7027896418151469 +[04:32:55.048] iteration 45000 [34.58 sec]: learning rate : 0.000125 loss : 0.315237 +[04:33:43.032] iteration 45100 [82.56 sec]: learning rate : 0.000125 loss : 0.416850 +[04:34:30.904] iteration 45200 [130.44 sec]: learning rate : 0.000125 loss : 0.422601 +[04:35:18.450] iteration 45300 [177.98 sec]: learning rate : 0.000125 loss : 0.351942 +[04:36:05.863] iteration 45400 [225.39 sec]: learning rate : 0.000125 loss : 0.435748 +[04:36:53.369] iteration 45500 [272.90 sec]: learning rate : 0.000125 loss : 0.465569 +[04:36:55.277] Epoch 78 Evaluation: +[04:38:23.570] average MSE: 0.05057343629203181 average PSNR: 25.991277173941295 average SSIM: 0.7043114266663218 +[04:39:09.573] iteration 45600 [45.98 sec]: learning rate : 0.000125 loss : 0.467917 +[04:39:57.061] iteration 45700 [93.47 sec]: learning rate : 0.000125 loss : 0.400180 +[04:40:44.677] iteration 45800 [141.08 sec]: learning rate : 0.000125 loss : 0.544391 +[04:41:32.634] iteration 45900 [189.04 sec]: learning rate : 0.000125 loss : 0.407637 +[04:42:20.038] iteration 46000 [236.44 sec]: learning rate : 0.000125 loss : 0.421593 +[04:42:58.080] Epoch 79 Evaluation: +[04:44:27.063] average MSE: 0.05133531972533945 average PSNR: 25.92554802097303 average SSIM: 0.7021258869228956 +[04:44:36.782] iteration 46100 [9.70 sec]: learning rate : 0.000125 loss : 0.415320 +[04:45:24.312] iteration 46200 [57.23 sec]: learning rate : 0.000125 loss : 0.409624 +[04:46:11.993] iteration 46300 [104.91 sec]: learning rate : 0.000125 loss : 0.427128 +[04:46:59.642] iteration 46400 [152.56 sec]: learning rate : 0.000125 loss : 0.468413 +[04:47:47.182] iteration 46500 [200.10 sec]: learning rate : 0.000125 loss : 0.494898 +[04:48:35.222] iteration 46600 [248.14 sec]: learning rate : 0.000125 loss : 0.594499 +[04:49:01.794] Epoch 80 Evaluation: +[04:50:29.282] average MSE: 0.05242181947893549 average PSNR: 25.833677791935493 average SSIM: 0.6989776212437099 +[04:50:50.554] iteration 46700 [21.25 sec]: learning rate : 0.000125 loss : 0.577253 +[04:51:38.064] iteration 46800 [68.76 sec]: learning rate : 0.000125 loss : 0.352247 +[04:52:26.054] iteration 46900 [116.75 sec]: learning rate : 0.000125 loss : 0.442742 +[04:53:13.719] iteration 47000 [164.41 sec]: learning rate : 0.000125 loss : 0.504491 +[04:54:01.124] iteration 47100 [211.82 sec]: learning rate : 0.000125 loss : 0.408654 +[04:54:49.117] iteration 47200 [259.81 sec]: learning rate : 0.000125 loss : 0.470808 +[04:55:04.296] Epoch 81 Evaluation: +[04:56:32.340] average MSE: 0.05273549994095615 average PSNR: 25.808884173418686 average SSIM: 0.6973754475320876 +[04:57:04.969] iteration 47300 [32.61 sec]: learning rate : 0.000125 loss : 0.362011 +[04:57:52.753] iteration 47400 [80.39 sec]: learning rate : 0.000125 loss : 0.231549 +[04:58:40.387] iteration 47500 [128.02 sec]: learning rate : 0.000125 loss : 0.402872 +[04:59:27.933] iteration 47600 [175.57 sec]: learning rate : 0.000125 loss : 0.439481 +[05:00:15.970] iteration 47700 [223.61 sec]: learning rate : 0.000125 loss : 0.439700 +[05:01:03.553] iteration 47800 
[271.19 sec]: learning rate : 0.000125 loss : 0.397731 +[05:01:07.352] Epoch 82 Evaluation: +[05:02:33.630] average MSE: 0.05306480498888812 average PSNR: 25.780274966481773 average SSIM: 0.6970411406447513 +[05:03:17.651] iteration 47900 [44.00 sec]: learning rate : 0.000125 loss : 0.360061 +[05:04:05.786] iteration 48000 [92.13 sec]: learning rate : 0.000125 loss : 0.426150 +[05:04:53.400] iteration 48100 [139.75 sec]: learning rate : 0.000125 loss : 0.453125 +[05:05:40.822] iteration 48200 [187.17 sec]: learning rate : 0.000125 loss : 0.446904 +[05:06:29.008] iteration 48300 [235.36 sec]: learning rate : 0.000125 loss : 0.337565 +[05:07:08.914] Epoch 83 Evaluation: +[05:08:36.731] average MSE: 0.051086731476897386 average PSNR: 25.947167761195804 average SSIM: 0.7036847260831663 +[05:08:44.690] iteration 48400 [7.93 sec]: learning rate : 0.000125 loss : 0.393589 +[05:09:32.077] iteration 48500 [55.32 sec]: learning rate : 0.000125 loss : 0.469238 +[05:10:20.001] iteration 48600 [103.24 sec]: learning rate : 0.000125 loss : 0.481110 +[05:11:08.103] iteration 48700 [151.35 sec]: learning rate : 0.000125 loss : 0.367759 +[05:11:56.167] iteration 48800 [199.41 sec]: learning rate : 0.000125 loss : 0.351253 +[05:12:43.670] iteration 48900 [246.91 sec]: learning rate : 0.000125 loss : 0.358757 +[05:13:12.193] Epoch 84 Evaluation: +[05:14:41.744] average MSE: 0.051252256840722515 average PSNR: 25.932837572196973 average SSIM: 0.7025823308146151 +[05:15:00.985] iteration 49000 [19.22 sec]: learning rate : 0.000125 loss : 0.445520 +[05:15:48.720] iteration 49100 [66.95 sec]: learning rate : 0.000125 loss : 0.390584 +[05:16:36.344] iteration 49200 [114.58 sec]: learning rate : 0.000125 loss : 0.462672 +[05:17:24.251] iteration 49300 [162.49 sec]: learning rate : 0.000125 loss : 0.553137 +[05:18:11.912] iteration 49400 [210.15 sec]: learning rate : 0.000125 loss : 0.398446 +[05:19:00.383] iteration 49500 [258.62 sec]: learning rate : 0.000125 loss : 0.505120 +[05:19:17.456] Epoch 85 Evaluation: +[05:20:46.538] average MSE: 0.05209866368187377 average PSNR: 25.86177263341184 average SSIM: 0.6997287788999473 +[05:21:17.060] iteration 49600 [30.50 sec]: learning rate : 0.000125 loss : 0.412194 +[05:22:04.608] iteration 49700 [78.05 sec]: learning rate : 0.000125 loss : 0.558471 +[05:22:52.508] iteration 49800 [125.95 sec]: learning rate : 0.000125 loss : 0.509288 +[05:23:40.063] iteration 49900 [173.50 sec]: learning rate : 0.000125 loss : 0.462910 +[05:24:27.763] iteration 50000 [221.20 sec]: learning rate : 0.000125 loss : 0.448631 +[05:25:15.313] iteration 50100 [268.75 sec]: learning rate : 0.000125 loss : 0.407378 +[05:25:21.153] Epoch 86 Evaluation: +[05:26:49.075] average MSE: 0.04997787427800111 average PSNR: 26.042720956764114 average SSIM: 0.7071021466249962 +[05:27:30.981] iteration 50200 [41.88 sec]: learning rate : 0.000125 loss : 14.448176 +[05:28:19.563] iteration 50300 [90.47 sec]: learning rate : 0.000125 loss : 0.406659 +[05:29:07.107] iteration 50400 [138.01 sec]: learning rate : 0.000125 loss : 0.384816 +[05:29:54.789] iteration 50500 [185.69 sec]: learning rate : 0.000125 loss : 0.322892 +[05:30:42.484] iteration 50600 [233.39 sec]: learning rate : 0.000125 loss : 0.389441 +[05:31:24.327] Epoch 87 Evaluation: +[05:32:55.428] average MSE: 0.051512356821650396 average PSNR: 25.911588153378315 average SSIM: 0.7011067859105818 +[05:33:01.370] iteration 50700 [5.92 sec]: learning rate : 0.000125 loss : 0.362410 +[05:33:50.240] iteration 50800 [54.79 sec]: learning rate : 0.000125 loss 
: 0.345445 +[05:34:37.697] iteration 50900 [102.25 sec]: learning rate : 0.000125 loss : 0.341241 +[05:35:25.057] iteration 51000 [149.61 sec]: learning rate : 0.000125 loss : 0.389284 +[05:36:12.510] iteration 51100 [197.06 sec]: learning rate : 0.000125 loss : 0.385948 +[05:37:00.432] iteration 51200 [244.98 sec]: learning rate : 0.000125 loss : 0.495439 +[05:37:30.939] Epoch 88 Evaluation: +[05:38:57.218] average MSE: 0.05158549512260951 average PSNR: 25.904743999154263 average SSIM: 0.7009837811538917 +[05:39:14.499] iteration 51300 [17.26 sec]: learning rate : 0.000125 loss : 0.485529 +[05:40:02.509] iteration 51400 [65.27 sec]: learning rate : 0.000125 loss : 0.353245 +[05:40:49.885] iteration 51500 [112.65 sec]: learning rate : 0.000125 loss : 0.355025 +[05:41:37.726] iteration 51600 [160.49 sec]: learning rate : 0.000125 loss : 0.390495 +[05:42:25.505] iteration 51700 [208.27 sec]: learning rate : 0.000125 loss : 0.414763 +[05:43:13.002] iteration 51800 [255.76 sec]: learning rate : 0.000125 loss : 0.525905 +[05:43:32.145] Epoch 89 Evaluation: +[05:45:03.638] average MSE: 0.051465289725215684 average PSNR: 25.914866453064043 average SSIM: 0.7014979970204313 +[05:45:32.378] iteration 51900 [28.72 sec]: learning rate : 0.000125 loss : 0.458942 +[05:46:20.482] iteration 52000 [76.82 sec]: learning rate : 0.000125 loss : 0.601490 +[05:47:07.962] iteration 52100 [124.30 sec]: learning rate : 0.000125 loss : 1.017944 +[05:47:55.551] iteration 52200 [171.89 sec]: learning rate : 0.000125 loss : 0.528862 +[05:48:43.477] iteration 52300 [219.82 sec]: learning rate : 0.000125 loss : 0.390984 +[05:49:30.995] iteration 52400 [267.33 sec]: learning rate : 0.000125 loss : 0.404065 +[05:49:38.619] Epoch 90 Evaluation: +[05:51:08.890] average MSE: 0.0494939722682545 average PSNR: 26.085317557521975 average SSIM: 0.7067528889061694 +[05:51:49.143] iteration 52500 [40.23 sec]: learning rate : 0.000125 loss : 0.349490 +[05:52:36.648] iteration 52600 [87.73 sec]: learning rate : 0.000125 loss : 0.446276 +[05:53:24.292] iteration 52700 [135.38 sec]: learning rate : 0.000125 loss : 0.388837 +[05:54:11.861] iteration 52800 [182.95 sec]: learning rate : 0.000125 loss : 0.371979 +[05:54:59.932] iteration 52900 [231.02 sec]: learning rate : 0.000125 loss : 0.449013 +[05:55:43.697] Epoch 91 Evaluation: +[05:57:14.095] average MSE: 0.05451487941823655 average PSNR: 25.660207662959703 average SSIM: 0.6918559929512266 +[05:57:18.137] iteration 53000 [4.02 sec]: learning rate : 0.000125 loss : 0.502875 +[05:58:05.730] iteration 53100 [51.61 sec]: learning rate : 0.000125 loss : 0.544498 +[05:58:53.104] iteration 53200 [98.99 sec]: learning rate : 0.000125 loss : 0.358594 +[05:59:40.602] iteration 53300 [146.48 sec]: learning rate : 0.000125 loss : 0.383613 +[06:00:28.129] iteration 53400 [194.01 sec]: learning rate : 0.000125 loss : 0.448052 +[06:01:15.532] iteration 53500 [241.41 sec]: learning rate : 0.000125 loss : 0.472628 +[06:01:48.310] Epoch 92 Evaluation: +[06:03:16.459] average MSE: 0.04937792787138451 average PSNR: 26.098244057455933 average SSIM: 0.7082501639358671 +[06:03:32.495] iteration 53600 [16.01 sec]: learning rate : 0.000125 loss : 0.519296 +[06:04:20.152] iteration 53700 [63.67 sec]: learning rate : 0.000125 loss : 0.408192 +[06:05:07.663] iteration 53800 [111.18 sec]: learning rate : 0.000125 loss : 0.460431 +[06:05:55.258] iteration 53900 [158.78 sec]: learning rate : 0.000125 loss : 0.378158 +[06:06:42.784] iteration 54000 [206.30 sec]: learning rate : 0.000125 loss : 0.414943 
+[06:07:30.749] iteration 54100 [254.27 sec]: learning rate : 0.000125 loss : 0.431237 +[06:07:51.696] Epoch 93 Evaluation: +[06:09:19.211] average MSE: 0.05386590496874594 average PSNR: 25.716009163280997 average SSIM: 0.6937890167078737 +[06:09:46.118] iteration 54200 [26.88 sec]: learning rate : 0.000125 loss : 0.483820 +[06:10:33.497] iteration 54300 [74.26 sec]: learning rate : 0.000125 loss : 0.462408 +[06:11:21.656] iteration 54400 [122.42 sec]: learning rate : 0.000125 loss : 0.386495 +[06:12:09.205] iteration 54500 [169.97 sec]: learning rate : 0.000125 loss : 0.437098 +[06:12:57.645] iteration 54600 [218.41 sec]: learning rate : 0.000125 loss : 0.402772 +[06:13:45.146] iteration 54700 [265.91 sec]: learning rate : 0.000125 loss : 0.499657 +[06:13:54.638] Epoch 94 Evaluation: +[06:15:21.207] average MSE: 0.05239693861094071 average PSNR: 25.837937651630085 average SSIM: 0.6984390666717163 +[06:15:59.496] iteration 54800 [38.27 sec]: learning rate : 0.000125 loss : 0.400136 +[06:16:46.880] iteration 54900 [85.65 sec]: learning rate : 0.000125 loss : 0.366892 +[06:17:34.358] iteration 55000 [133.13 sec]: learning rate : 0.000125 loss : 0.374269 +[06:18:22.653] iteration 55100 [181.42 sec]: learning rate : 0.000125 loss : 0.489497 +[06:19:10.591] iteration 55200 [229.36 sec]: learning rate : 0.000125 loss : 0.477412 +[06:19:56.203] Epoch 95 Evaluation: +[06:21:23.446] average MSE: 0.051836654641206924 average PSNR: 25.8866247638652 average SSIM: 0.698541682232876 +[06:21:25.552] iteration 55300 [2.08 sec]: learning rate : 0.000125 loss : 0.411819 +[06:22:13.087] iteration 55400 [49.62 sec]: learning rate : 0.000125 loss : 0.402941 +[06:23:00.642] iteration 55500 [97.17 sec]: learning rate : 0.000125 loss : 0.444186 +[06:23:48.123] iteration 55600 [144.66 sec]: learning rate : 0.000125 loss : 0.596707 +[06:24:36.127] iteration 55700 [192.66 sec]: learning rate : 0.000125 loss : 0.482290 +[06:25:23.635] iteration 55800 [240.17 sec]: learning rate : 0.000125 loss : 0.388847 +[06:25:58.148] Epoch 96 Evaluation: +[06:27:26.380] average MSE: 0.0851875384766181 average PSNR: 23.731126975897745 average SSIM: 0.6330967622876165 +[06:27:40.069] iteration 55900 [13.67 sec]: learning rate : 0.000125 loss : 0.431111 +[06:28:27.581] iteration 56000 [61.18 sec]: learning rate : 0.000125 loss : 0.487189 +[06:29:15.157] iteration 56100 [108.76 sec]: learning rate : 0.000125 loss : 0.484160 +[06:30:03.259] iteration 56200 [156.86 sec]: learning rate : 0.000125 loss : 0.555880 +[06:30:51.087] iteration 56300 [204.68 sec]: learning rate : 0.000125 loss : 0.486808 +[06:31:38.662] iteration 56400 [252.26 sec]: learning rate : 0.000125 loss : 0.673114 +[06:32:01.462] Epoch 97 Evaluation: +[06:33:28.131] average MSE: 0.08800669670416085 average PSNR: 23.587413585228997 average SSIM: 0.6327868565753291 +[06:33:53.087] iteration 56500 [24.93 sec]: learning rate : 0.000125 loss : 0.365050 +[06:34:40.680] iteration 56600 [72.53 sec]: learning rate : 0.000125 loss : 0.360132 +[06:35:28.599] iteration 56700 [120.44 sec]: learning rate : 0.000125 loss : 0.325699 +[06:36:15.983] iteration 56800 [167.83 sec]: learning rate : 0.000125 loss : 0.379001 +[06:37:03.571] iteration 56900 [215.42 sec]: learning rate : 0.000125 loss : 17.621473 +[06:37:51.057] iteration 57000 [262.90 sec]: learning rate : 0.000125 loss : 0.466259 +[06:38:02.460] Epoch 98 Evaluation: +[06:39:29.295] average MSE: 0.05312002761437452 average PSNR: 25.78168426613071 average SSIM: 0.696189907588102 +[06:40:05.638] iteration 57100 [36.32 sec]: 
learning rate : 0.000125 loss : 0.311674 +[06:40:54.399] iteration 57200 [85.08 sec]: learning rate : 0.000125 loss : 0.395882 +[06:41:41.786] iteration 57300 [132.47 sec]: learning rate : 0.000125 loss : 0.388467 +[06:42:29.303] iteration 57400 [179.99 sec]: learning rate : 0.000125 loss : 0.409514 +[06:43:16.898] iteration 57500 [227.58 sec]: learning rate : 0.000125 loss : 0.370937 +[06:44:04.310] iteration 57600 [274.99 sec]: learning rate : 0.000125 loss : 0.357531 +[06:44:04.348] Epoch 99 Evaluation: +[06:45:31.401] average MSE: 0.052650239726243914 average PSNR: 25.82019509869482 average SSIM: 0.6990421559646904 +[06:46:19.720] iteration 57700 [48.29 sec]: learning rate : 0.000125 loss : 0.405735 +[06:47:07.384] iteration 57800 [95.96 sec]: learning rate : 0.000125 loss : 0.472252 +[06:47:55.170] iteration 57900 [143.74 sec]: learning rate : 0.000125 loss : 0.336409 +[06:48:43.178] iteration 58000 [191.75 sec]: learning rate : 0.000125 loss : 0.383808 +[06:49:30.687] iteration 58100 [239.26 sec]: learning rate : 0.000125 loss : 0.463809 +[06:50:06.730] Epoch 100 Evaluation: +[06:51:32.972] average MSE: 0.0511342198811734 average PSNR: 25.949863109700136 average SSIM: 0.7037307655119311 +[06:51:44.555] iteration 58200 [11.56 sec]: learning rate : 0.000125 loss : 0.441292 +[06:52:32.498] iteration 58300 [59.50 sec]: learning rate : 0.000125 loss : 0.589213 +[06:53:19.947] iteration 58400 [106.96 sec]: learning rate : 0.000125 loss : 0.388197 +[06:54:07.327] iteration 58500 [154.33 sec]: learning rate : 0.000125 loss : 0.339473 +[06:54:54.800] iteration 58600 [201.80 sec]: learning rate : 0.000125 loss : 0.444905 +[06:55:42.622] iteration 58700 [249.63 sec]: learning rate : 0.000125 loss : 0.452666 +[06:56:07.387] Epoch 101 Evaluation: +[06:57:33.553] average MSE: 0.05398568508893285 average PSNR: 25.710656838892398 average SSIM: 0.693730089535295 +[06:57:56.512] iteration 58800 [22.94 sec]: learning rate : 0.000125 loss : 0.328732 +[06:58:44.102] iteration 58900 [70.53 sec]: learning rate : 0.000125 loss : 0.434677 +[06:59:31.507] iteration 59000 [117.93 sec]: learning rate : 0.000125 loss : 0.424023 +[07:00:19.007] iteration 59100 [165.43 sec]: learning rate : 0.000125 loss : 0.458471 +[07:01:06.667] iteration 59200 [213.09 sec]: learning rate : 0.000125 loss : 0.338683 +[07:01:54.267] iteration 59300 [260.69 sec]: learning rate : 0.000125 loss : 0.370261 +[07:02:07.630] Epoch 102 Evaluation: +[07:03:36.961] average MSE: 0.053031972791954755 average PSNR: 25.791905647206367 average SSIM: 0.698217172139358 +[07:04:11.495] iteration 59400 [34.51 sec]: learning rate : 0.000125 loss : 0.369350 +[07:04:59.112] iteration 59500 [82.13 sec]: learning rate : 0.000125 loss : 0.374882 +[07:05:46.522] iteration 59600 [129.54 sec]: learning rate : 0.000125 loss : 0.395667 +[07:06:34.425] iteration 59700 [177.44 sec]: learning rate : 0.000125 loss : 0.330299 +[07:07:21.854] iteration 59800 [224.87 sec]: learning rate : 0.000125 loss : 0.329223 +[07:08:09.351] iteration 59900 [272.37 sec]: learning rate : 0.000125 loss : 0.363498 +[07:08:11.257] Epoch 103 Evaluation: +[07:09:38.381] average MSE: 0.05536897139253464 average PSNR: 25.603759969274595 average SSIM: 0.6909275506669023 +[07:10:24.261] iteration 60000 [45.86 sec]: learning rate : 0.000031 loss : 0.464557 +[07:10:24.424] save model to model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/iter_60000.pth +[07:11:12.296] iteration 60100 [93.89 sec]: learning rate : 0.000063 loss : 0.549483 +[07:11:59.840] iteration 60200 [141.44 sec]: learning 
rate : 0.000063 loss : 0.462175 +[07:12:47.368] iteration 60300 [188.96 sec]: learning rate : 0.000063 loss : 0.405148 +[07:13:34.898] iteration 60400 [236.49 sec]: learning rate : 0.000063 loss : 0.373908 +[07:14:13.045] Epoch 104 Evaluation: +[07:15:41.835] average MSE: 0.06674071889856192 average PSNR: 24.786322240545505 average SSIM: 0.6699198184527031 +[07:15:51.638] iteration 60500 [9.78 sec]: learning rate : 0.000063 loss : 0.339766 +[07:16:39.331] iteration 60600 [57.47 sec]: learning rate : 0.000063 loss : 0.361510 +[07:17:26.810] iteration 60700 [104.95 sec]: learning rate : 0.000063 loss : 0.428457 +[07:18:14.756] iteration 60800 [152.90 sec]: learning rate : 0.000063 loss : 0.530132 +[07:19:02.283] iteration 60900 [200.43 sec]: learning rate : 0.000063 loss : 0.516602 +[07:19:49.682] iteration 61000 [247.82 sec]: learning rate : 0.000063 loss : 0.438719 +[07:20:16.798] Epoch 105 Evaluation: +[07:21:44.109] average MSE: 0.06791618849899082 average PSNR: 24.71042895071977 average SSIM: 0.6680433111306943 +[07:22:05.250] iteration 61100 [21.12 sec]: learning rate : 0.000063 loss : 0.421836 +[07:22:52.768] iteration 61200 [68.64 sec]: learning rate : 0.000063 loss : 0.383560 +[07:23:40.743] iteration 61300 [116.61 sec]: learning rate : 0.000063 loss : 0.443851 +[07:24:28.458] iteration 61400 [164.33 sec]: learning rate : 0.000063 loss : 0.479040 +[07:25:16.050] iteration 61500 [211.92 sec]: learning rate : 0.000063 loss : 0.360193 +[07:26:04.507] iteration 61600 [260.38 sec]: learning rate : 0.000063 loss : 0.370536 +[07:26:19.699] Epoch 106 Evaluation: +[07:27:46.073] average MSE: 0.07317738927851715 average PSNR: 24.38509628143269 average SSIM: 0.6626636846664623 +[07:28:18.710] iteration 61700 [32.61 sec]: learning rate : 0.000063 loss : 0.406160 +[07:29:06.148] iteration 61800 [80.05 sec]: learning rate : 0.000063 loss : 0.303865 +[07:29:53.717] iteration 61900 [127.62 sec]: learning rate : 0.000063 loss : 0.435578 +[07:30:41.193] iteration 62000 [175.10 sec]: learning rate : 0.000063 loss : 0.440674 +[07:31:29.213] iteration 62100 [223.12 sec]: learning rate : 0.000063 loss : 0.483128 +[07:32:16.767] iteration 62200 [270.67 sec]: learning rate : 0.000063 loss : 0.421714 +[07:32:20.579] Epoch 107 Evaluation: +[07:33:51.072] average MSE: 0.07566363563085217 average PSNR: 24.24179431826789 average SSIM: 0.6540678588223667 +[07:34:34.985] iteration 62300 [43.89 sec]: learning rate : 0.000063 loss : 0.263800 +[07:35:22.676] iteration 62400 [91.58 sec]: learning rate : 0.000063 loss : 0.397773 +[07:36:10.336] iteration 62500 [139.24 sec]: learning rate : 0.000063 loss : 0.486722 +[07:36:58.303] iteration 62600 [187.21 sec]: learning rate : 0.000063 loss : 0.323624 +[07:37:45.953] iteration 62700 [234.86 sec]: learning rate : 0.000063 loss : 0.308550 +[07:38:25.965] Epoch 108 Evaluation: +[07:39:52.290] average MSE: 0.08544215210246435 average PSNR: 23.71765482007473 average SSIM: 0.6411319279686032 +[07:40:00.079] iteration 62800 [7.77 sec]: learning rate : 0.000063 loss : 0.408218 +[07:40:47.867] iteration 62900 [55.55 sec]: learning rate : 0.000063 loss : 0.344436 +[07:41:35.671] iteration 63000 [103.36 sec]: learning rate : 0.000063 loss : 0.476835 +[07:42:23.582] iteration 63100 [151.27 sec]: learning rate : 0.000063 loss : 0.410108 +[07:43:11.097] iteration 63200 [198.79 sec]: learning rate : 0.000063 loss : 0.378362 +[07:43:58.770] iteration 63300 [246.46 sec]: learning rate : 0.000063 loss : 0.418679 +[07:44:27.366] Epoch 109 Evaluation: +[07:45:57.793] average MSE: 
0.092951508873772 average PSNR: 23.344960051548718 average SSIM: 0.6293584831010125 +[07:46:17.100] iteration 63400 [19.28 sec]: learning rate : 0.000063 loss : 0.445147 +[07:47:04.517] iteration 63500 [66.70 sec]: learning rate : 0.000063 loss : 0.510166 +[07:47:52.654] iteration 63600 [114.84 sec]: learning rate : 0.000063 loss : 0.419645 +[07:48:40.493] iteration 63700 [162.68 sec]: learning rate : 0.000063 loss : 0.383112 +[07:49:27.993] iteration 63800 [210.18 sec]: learning rate : 0.000063 loss : 0.435966 +[07:50:15.479] iteration 63900 [257.66 sec]: learning rate : 0.000063 loss : 0.377602 +[07:50:32.925] Epoch 110 Evaluation: +[07:51:58.922] average MSE: 0.10434068965680815 average PSNR: 22.850003737777996 average SSIM: 0.6100088319582084 +[07:52:29.659] iteration 64000 [30.71 sec]: learning rate : 0.000063 loss : 0.420914 +[07:53:17.379] iteration 64100 [78.43 sec]: learning rate : 0.000063 loss : 0.479779 +[07:54:05.619] iteration 64200 [126.67 sec]: learning rate : 0.000063 loss : 0.486305 +[07:54:53.149] iteration 64300 [174.20 sec]: learning rate : 0.000063 loss : 0.520115 +[07:55:40.810] iteration 64400 [221.86 sec]: learning rate : 0.000063 loss : 0.487530 +[07:56:29.061] iteration 64500 [270.12 sec]: learning rate : 0.000063 loss : 0.384842 +[07:56:34.785] Epoch 111 Evaluation: +[07:58:04.976] average MSE: 0.11702566749150281 average PSNR: 22.35575578667725 average SSIM: 0.5961862071499989 +[07:58:46.975] iteration 64600 [41.98 sec]: learning rate : 0.000063 loss : 0.435859 +[07:59:35.539] iteration 64700 [90.54 sec]: learning rate : 0.000063 loss : 0.317505 +[08:00:23.052] iteration 64800 [138.05 sec]: learning rate : 0.000063 loss : 0.502473 +[08:01:10.581] iteration 64900 [185.58 sec]: learning rate : 0.000063 loss : 0.344848 +[08:01:58.068] iteration 65000 [233.07 sec]: learning rate : 0.000063 loss : 0.461914 +[08:02:39.763] Epoch 112 Evaluation: +[08:04:06.710] average MSE: 0.12442987829695827 average PSNR: 22.091906040111624 average SSIM: 0.5913309853219723 +[08:04:12.635] iteration 65100 [5.90 sec]: learning rate : 0.000063 loss : 0.313232 +[08:05:00.931] iteration 65200 [54.20 sec]: learning rate : 0.000063 loss : 0.337380 +[08:05:48.423] iteration 65300 [101.69 sec]: learning rate : 0.000063 loss : 0.363168 +[08:06:35.824] iteration 65400 [149.09 sec]: learning rate : 0.000063 loss : 0.336397 +[08:07:23.343] iteration 65500 [196.61 sec]: learning rate : 0.000063 loss : 0.338446 +[08:08:10.840] iteration 65600 [244.11 sec]: learning rate : 0.000063 loss : 0.428988 +[08:08:41.569] Epoch 113 Evaluation: +[08:10:09.409] average MSE: 0.11206499634101852 average PSNR: 22.544519784166305 average SSIM: 0.5893774388968501 +[08:10:26.675] iteration 65700 [17.24 sec]: learning rate : 0.000063 loss : 0.408425 +[08:11:15.006] iteration 65800 [65.57 sec]: learning rate : 0.000063 loss : 0.487410 +[08:12:02.441] iteration 65900 [113.02 sec]: learning rate : 0.000063 loss : 0.435236 +[08:12:50.137] iteration 66000 [160.71 sec]: learning rate : 0.000063 loss : 0.364924 +[08:13:37.735] iteration 66100 [208.30 sec]: learning rate : 0.000063 loss : 0.457170 +[08:14:25.266] iteration 66200 [255.83 sec]: learning rate : 0.000063 loss : 0.372613 +[08:14:44.369] Epoch 114 Evaluation: +[08:16:11.522] average MSE: 0.06909053809148488 average PSNR: 24.647169222334252 average SSIM: 0.6649611142322535 +[08:16:40.162] iteration 66300 [28.62 sec]: learning rate : 0.000063 loss : 0.447624 +[08:17:28.145] iteration 66400 [76.60 sec]: learning rate : 0.000063 loss : 0.475092 +[08:18:15.835] 
iteration 66500 [124.29 sec]: learning rate : 0.000063 loss : 0.424600 +[08:19:03.296] iteration 66600 [171.75 sec]: learning rate : 0.000063 loss : 0.581109 +[08:19:50.883] iteration 66700 [219.34 sec]: learning rate : 0.000063 loss : 0.390202 +[08:20:38.382] iteration 66800 [266.84 sec]: learning rate : 0.000063 loss : 0.303917 +[08:20:45.992] Epoch 115 Evaluation: +[08:22:15.464] average MSE: 0.07708600393387355 average PSNR: 24.18455629039276 average SSIM: 0.6554967783631752 +[08:22:55.712] iteration 66900 [40.22 sec]: learning rate : 0.000063 loss : 0.344874 +[08:23:43.077] iteration 67000 [87.59 sec]: learning rate : 0.000063 loss : 0.480162 +[08:24:30.531] iteration 67100 [135.04 sec]: learning rate : 0.000063 loss : 0.430927 +[08:25:18.022] iteration 67200 [182.53 sec]: learning rate : 0.000063 loss : 0.427506 +[08:26:05.869] iteration 67300 [230.38 sec]: learning rate : 0.000063 loss : 0.382883 +[08:26:50.152] Epoch 116 Evaluation: +[08:28:18.866] average MSE: 0.11029201407471707 average PSNR: 22.624758600575007 average SSIM: 0.5974928517284825 +[08:28:22.876] iteration 67400 [3.99 sec]: learning rate : 0.000063 loss : 0.479878 +[08:29:10.578] iteration 67500 [51.69 sec]: learning rate : 0.000063 loss : 0.476237 +[08:29:58.086] iteration 67600 [99.20 sec]: learning rate : 0.000063 loss : 0.331223 +[08:30:45.634] iteration 67700 [146.75 sec]: learning rate : 0.000063 loss : 0.450038 +[08:31:33.273] iteration 67800 [194.38 sec]: learning rate : 0.000063 loss : 0.343870 +[08:32:20.677] iteration 67900 [241.79 sec]: learning rate : 0.000063 loss : 0.418972 +[08:32:53.404] Epoch 117 Evaluation: +[08:34:24.601] average MSE: 0.14184986002734076 average PSNR: 21.5301625659181 average SSIM: 0.5676497809656398 +[08:34:40.070] iteration 68000 [15.45 sec]: learning rate : 0.000063 loss : 0.410920 +[08:35:28.170] iteration 68100 [63.55 sec]: learning rate : 0.000063 loss : 0.338255 +[08:36:15.784] iteration 68200 [111.16 sec]: learning rate : 0.000063 loss : 0.547057 +[08:37:03.329] iteration 68300 [158.71 sec]: learning rate : 0.000063 loss : 0.488108 +[08:37:50.830] iteration 68400 [206.21 sec]: learning rate : 0.000063 loss : 0.428033 +[08:38:38.924] iteration 68500 [254.30 sec]: learning rate : 0.000063 loss : 0.481096 +[08:38:59.842] Epoch 118 Evaluation: +[08:40:29.084] average MSE: 0.10286654952470824 average PSNR: 22.921233054054625 average SSIM: 0.599842755233686 +[08:40:56.596] iteration 68600 [27.49 sec]: learning rate : 0.000063 loss : 0.406046 +[08:41:44.092] iteration 68700 [74.98 sec]: learning rate : 0.000063 loss : 0.393894 +[08:42:31.674] iteration 68800 [122.57 sec]: learning rate : 0.000063 loss : 0.374362 +[08:43:19.269] iteration 68900 [170.16 sec]: learning rate : 0.000063 loss : 0.469872 +[08:44:07.319] iteration 69000 [218.21 sec]: learning rate : 0.000063 loss : 0.557120 +[08:44:54.829] iteration 69100 [265.72 sec]: learning rate : 0.000063 loss : 0.572905 +[08:45:04.341] Epoch 119 Evaluation: +[08:46:31.091] average MSE: 0.13804917316355741 average PSNR: 21.63078541778945 average SSIM: 0.5861871520480639 +[08:47:09.484] iteration 69200 [38.37 sec]: learning rate : 0.000063 loss : 0.321090 +[08:47:57.190] iteration 69300 [86.08 sec]: learning rate : 0.000063 loss : 0.389963 +[08:48:45.040] iteration 69400 [133.92 sec]: learning rate : 0.000063 loss : 0.404854 +[08:49:32.904] iteration 69500 [181.79 sec]: learning rate : 0.000063 loss : 0.376185 +[08:50:20.388] iteration 69600 [229.27 sec]: learning rate : 0.000063 loss : 0.408432 +[08:51:06.033] Epoch 120 Evaluation: 
+[08:52:32.205] average MSE: 0.09035843684535477 average PSNR: 23.510812581572218 average SSIM: 0.6341677017940338 +[08:52:34.315] iteration 69700 [2.09 sec]: learning rate : 0.000063 loss : 0.325733 +[08:53:22.144] iteration 69800 [49.92 sec]: learning rate : 0.000063 loss : 0.419548 +[08:54:09.643] iteration 69900 [97.41 sec]: learning rate : 0.000063 loss : 0.398569 +[08:54:57.537] iteration 70000 [145.31 sec]: learning rate : 0.000063 loss : 0.465659 +[08:55:45.326] iteration 70100 [193.10 sec]: learning rate : 0.000063 loss : 0.322859 +[08:56:32.825] iteration 70200 [240.60 sec]: learning rate : 0.000063 loss : 0.399992 +[08:57:06.955] Epoch 121 Evaluation: +[08:58:32.695] average MSE: 0.10041219576872687 average PSNR: 23.027077436771872 average SSIM: 0.6108187713124608 +[08:58:46.169] iteration 70300 [13.45 sec]: learning rate : 0.000063 loss : 0.416058 +[08:59:33.718] iteration 70400 [61.00 sec]: learning rate : 0.000063 loss : 0.391837 +[09:00:21.155] iteration 70500 [108.44 sec]: learning rate : 0.000063 loss : 0.407948 +[09:01:09.027] iteration 70600 [156.31 sec]: learning rate : 0.000063 loss : 0.487194 +[09:01:56.956] iteration 70700 [204.24 sec]: learning rate : 0.000063 loss : 0.361866 +[09:02:44.534] iteration 70800 [251.82 sec]: learning rate : 0.000063 loss : 0.467753 +[09:03:07.732] Epoch 122 Evaluation: +[09:04:34.122] average MSE: 0.12918005450374517 average PSNR: 21.928011200279162 average SSIM: 0.580232456563205 +[09:04:59.044] iteration 70900 [24.90 sec]: learning rate : 0.000063 loss : 0.396948 +[09:05:46.600] iteration 71000 [72.45 sec]: learning rate : 0.000063 loss : 0.350258 +[09:06:34.707] iteration 71100 [120.56 sec]: learning rate : 0.000063 loss : 0.453464 +[09:07:22.316] iteration 71200 [168.17 sec]: learning rate : 0.000063 loss : 0.384755 +[09:08:10.040] iteration 71300 [215.89 sec]: learning rate : 0.000063 loss : 0.550388 +[09:08:57.670] iteration 71400 [263.52 sec]: learning rate : 0.000063 loss : 0.326244 +[09:09:09.093] Epoch 123 Evaluation: +[09:10:39.829] average MSE: 0.154144088048766 average PSNR: 21.154790588049337 average SSIM: 0.5604438161024102 +[09:11:16.201] iteration 71500 [36.35 sec]: learning rate : 0.000063 loss : 0.518151 +[09:12:04.102] iteration 71600 [84.25 sec]: learning rate : 0.000063 loss : 0.403332 +[09:12:51.461] iteration 71700 [131.61 sec]: learning rate : 0.000063 loss : 0.325240 +[09:13:38.949] iteration 71800 [179.10 sec]: learning rate : 0.000063 loss : 0.423178 +[09:14:26.561] iteration 71900 [226.71 sec]: learning rate : 0.000063 loss : 0.389483 +[09:15:14.055] iteration 72000 [274.20 sec]: learning rate : 0.000063 loss : 0.376485 +[09:15:14.102] Epoch 124 Evaluation: +[09:16:40.505] average MSE: 0.1792097695583279 average PSNR: 20.510036727103614 average SSIM: 0.5541657798938253 +[09:17:28.927] iteration 72100 [48.40 sec]: learning rate : 0.000063 loss : 0.272848 +[09:18:16.732] iteration 72200 [96.20 sec]: learning rate : 0.000063 loss : 0.459931 +[09:19:04.088] iteration 72300 [143.56 sec]: learning rate : 0.000063 loss : 0.378906 +[09:19:51.808] iteration 72400 [191.28 sec]: learning rate : 0.000063 loss : 0.361388 +[09:20:39.479] iteration 72500 [238.95 sec]: learning rate : 0.000063 loss : 0.403864 +[09:21:15.513] Epoch 125 Evaluation: +[09:22:47.249] average MSE: 0.1632176964544554 average PSNR: 20.90982221549066 average SSIM: 0.5624769700580672 +[09:22:58.861] iteration 72600 [11.59 sec]: learning rate : 0.000063 loss : 0.373886 +[09:23:46.524] iteration 72700 [59.25 sec]: learning rate : 0.000063 loss : 
0.336666 +[09:24:33.909] iteration 72800 [106.64 sec]: learning rate : 0.000063 loss : 0.513353 +[09:25:21.386] iteration 72900 [154.11 sec]: learning rate : 0.000063 loss : 0.416968 +[09:26:09.379] iteration 73000 [202.11 sec]: learning rate : 0.000063 loss : 0.501786 +[09:26:56.815] iteration 73100 [249.54 sec]: learning rate : 0.000063 loss : 0.461485 +[09:27:21.673] Epoch 126 Evaluation: +[09:28:50.030] average MSE: 0.13058104360408113 average PSNR: 21.881459318418354 average SSIM: 0.5808308611003685 +[09:29:12.982] iteration 73200 [22.93 sec]: learning rate : 0.000063 loss : 0.395995 +[09:30:00.617] iteration 73300 [70.57 sec]: learning rate : 0.000063 loss : 0.323763 +[09:30:48.113] iteration 73400 [118.06 sec]: learning rate : 0.000063 loss : 0.520040 +[09:31:35.700] iteration 73500 [165.65 sec]: learning rate : 0.000063 loss : 0.408539 +[09:32:23.372] iteration 73600 [213.32 sec]: learning rate : 0.000063 loss : 0.351782 +[09:33:10.892] iteration 73700 [260.84 sec]: learning rate : 0.000063 loss : 0.268895 +[09:33:25.040] Epoch 127 Evaluation: +[09:34:55.962] average MSE: 0.13605737951500071 average PSNR: 21.697862604469478 average SSIM: 0.5838809562440193 +[09:35:30.603] iteration 73800 [34.62 sec]: learning rate : 0.000063 loss : 0.370083 +[09:36:18.153] iteration 73900 [82.17 sec]: learning rate : 0.000063 loss : 0.379371 +[09:37:05.813] iteration 74000 [129.83 sec]: learning rate : 0.000063 loss : 0.442067 +[09:37:53.881] iteration 74100 [177.90 sec]: learning rate : 0.000063 loss : 0.366402 +[09:38:41.440] iteration 74200 [225.46 sec]: learning rate : 0.000063 loss : 0.426516 +[09:39:29.060] iteration 74300 [273.08 sec]: learning rate : 0.000063 loss : 0.285528 +[09:39:31.237] Epoch 128 Evaluation: +[09:40:58.563] average MSE: 0.14236657225351063 average PSNR: 21.505982640580477 average SSIM: 0.5741002347172939 +[09:41:44.420] iteration 74400 [45.83 sec]: learning rate : 0.000063 loss : 0.498098 +[09:42:31.801] iteration 74500 [93.22 sec]: learning rate : 0.000063 loss : 0.542931 +[09:43:19.300] iteration 74600 [140.72 sec]: learning rate : 0.000063 loss : 0.457950 +[09:44:06.801] iteration 74700 [188.22 sec]: learning rate : 0.000063 loss : 0.331165 +[09:44:54.195] iteration 74800 [235.61 sec]: learning rate : 0.000063 loss : 0.403616 +[09:45:32.640] Epoch 129 Evaluation: +[09:46:59.304] average MSE: 0.1526183482007697 average PSNR: 21.202714207222904 average SSIM: 0.5691590839734006 +[09:47:08.988] iteration 74900 [9.66 sec]: learning rate : 0.000063 loss : 0.371631 +[09:47:56.625] iteration 75000 [57.30 sec]: learning rate : 0.000063 loss : 0.358966 +[09:48:44.302] iteration 75100 [104.98 sec]: learning rate : 0.000063 loss : 0.485399 +[09:49:31.771] iteration 75200 [152.45 sec]: learning rate : 0.000063 loss : 0.459925 +[09:50:19.171] iteration 75300 [199.85 sec]: learning rate : 0.000063 loss : 0.437019 +[09:51:07.070] iteration 75400 [247.74 sec]: learning rate : 0.000063 loss : 0.432755 +[09:51:33.639] Epoch 130 Evaluation: +[09:52:59.795] average MSE: 0.17937353879334608 average PSNR: 20.51795234903749 average SSIM: 0.5617818964545189 +[09:53:21.022] iteration 75500 [21.20 sec]: learning rate : 0.000063 loss : 0.533761 +[09:54:08.419] iteration 75600 [68.60 sec]: learning rate : 0.000063 loss : 0.380058 +[09:54:55.884] iteration 75700 [116.07 sec]: learning rate : 0.000063 loss : 0.394642 +[09:55:44.278] iteration 75800 [164.46 sec]: learning rate : 0.000063 loss : 0.489339 +[09:56:32.369] iteration 75900 [212.55 sec]: learning rate : 0.000063 loss : 0.438549 
+[09:57:19.935] iteration 76000 [260.12 sec]: learning rate : 0.000063 loss : 0.460096 +[09:57:35.150] Epoch 131 Evaluation: +[09:59:01.763] average MSE: 0.15151498408420988 average PSNR: 21.238593291575246 average SSIM: 0.5760412619007543 +[09:59:34.375] iteration 76100 [32.59 sec]: learning rate : 0.000063 loss : 0.286089 +[10:00:22.130] iteration 76200 [80.34 sec]: learning rate : 0.000063 loss : 0.226616 +[10:01:09.786] iteration 76300 [128.00 sec]: learning rate : 0.000063 loss : 0.366477 +[10:01:57.718] iteration 76400 [175.93 sec]: learning rate : 0.000063 loss : 0.375805 +[10:02:45.313] iteration 76500 [223.53 sec]: learning rate : 0.000063 loss : 0.473951 +[10:03:33.303] iteration 76600 [271.52 sec]: learning rate : 0.000063 loss : 0.445234 +[10:03:37.110] Epoch 132 Evaluation: +[10:05:06.921] average MSE: 0.13485695676310847 average PSNR: 21.739494618376956 average SSIM: 0.578986537963838 +[10:05:50.999] iteration 76700 [44.06 sec]: learning rate : 0.000063 loss : 0.300211 +[10:06:38.752] iteration 76800 [91.81 sec]: learning rate : 0.000063 loss : 0.412661 +[10:07:26.742] iteration 76900 [139.80 sec]: learning rate : 0.000063 loss : 0.417800 +[10:08:14.174] iteration 77000 [187.23 sec]: learning rate : 0.000063 loss : 0.485991 +[10:09:01.755] iteration 77100 [234.81 sec]: learning rate : 0.000063 loss : 0.279820 +[10:09:41.658] Epoch 133 Evaluation: +[10:11:09.412] average MSE: 0.16759465443736604 average PSNR: 20.789583286935667 average SSIM: 0.5520202614442705 +[10:11:17.231] iteration 77200 [7.80 sec]: learning rate : 0.000063 loss : 0.428330 +[10:12:04.971] iteration 77300 [55.54 sec]: learning rate : 0.000063 loss : 0.384217 +[10:12:52.449] iteration 77400 [103.01 sec]: learning rate : 0.000063 loss : 0.451235 +[10:13:40.530] iteration 77500 [151.09 sec]: learning rate : 0.000063 loss : 0.405939 +[10:14:28.154] iteration 77600 [198.72 sec]: learning rate : 0.000063 loss : 0.402043 +[10:15:15.706] iteration 77700 [246.27 sec]: learning rate : 0.000063 loss : 0.391671 +[10:15:44.186] Epoch 134 Evaluation: +[10:17:11.026] average MSE: 0.15736514884330216 average PSNR: 21.062674144077892 average SSIM: 0.5624861337092445 +[10:17:30.270] iteration 77800 [19.22 sec]: learning rate : 0.000063 loss : 0.314850 +[10:18:17.995] iteration 77900 [66.96 sec]: learning rate : 0.000063 loss : 0.584513 +[10:19:06.226] iteration 78000 [115.18 sec]: learning rate : 0.000063 loss : 0.419929 +[10:19:53.720] iteration 78100 [162.67 sec]: learning rate : 0.000063 loss : 0.391526 +[10:20:41.360] iteration 78200 [210.31 sec]: learning rate : 0.000063 loss : 0.423350 +[10:21:28.761] iteration 78300 [257.71 sec]: learning rate : 0.000063 loss : 0.489596 +[10:21:45.934] Epoch 135 Evaluation: +[10:23:12.518] average MSE: 0.10273389321827223 average PSNR: 22.911398207041287 average SSIM: 0.611735139221572 +[10:23:43.062] iteration 78400 [30.52 sec]: learning rate : 0.000063 loss : 0.399681 +[10:24:30.999] iteration 78500 [78.46 sec]: learning rate : 0.000063 loss : 0.457056 +[10:25:18.379] iteration 78600 [125.84 sec]: learning rate : 0.000063 loss : 0.538921 +[10:26:06.265] iteration 78700 [173.72 sec]: learning rate : 0.000063 loss : 0.390287 +[10:26:53.870] iteration 78800 [221.33 sec]: learning rate : 0.000063 loss : 0.527601 +[10:27:41.294] iteration 78900 [268.75 sec]: learning rate : 0.000063 loss : 0.429637 +[10:27:47.003] Epoch 136 Evaluation: +[10:29:14.632] average MSE: 0.15518326868421087 average PSNR: 21.133801129744445 average SSIM: 0.5619964060981958 +[10:29:57.001] iteration 79000 [42.35 
sec]: learning rate : 0.000063 loss : 0.438572 +[10:30:44.665] iteration 79100 [90.01 sec]: learning rate : 0.000063 loss : 0.308022 +[10:31:32.708] iteration 79200 [138.05 sec]: learning rate : 0.000063 loss : 0.348169 +[10:32:20.389] iteration 79300 [185.74 sec]: learning rate : 0.000063 loss : 0.285727 +[10:33:08.006] iteration 79400 [233.35 sec]: learning rate : 0.000063 loss : 0.380207 +[10:33:50.183] Epoch 137 Evaluation: +[10:35:17.802] average MSE: 0.12547630164506637 average PSNR: 22.054255720339626 average SSIM: 0.5851438641001281 +[10:35:23.919] iteration 79500 [6.09 sec]: learning rate : 0.000063 loss : 0.338381 +[10:36:11.552] iteration 79600 [53.73 sec]: learning rate : 0.000063 loss : 0.441189 +[10:36:58.960] iteration 79700 [101.13 sec]: learning rate : 0.000063 loss : 0.412141 +[10:37:46.458] iteration 79800 [148.63 sec]: learning rate : 0.000063 loss : 0.427986 +[10:38:33.949] iteration 79900 [196.12 sec]: learning rate : 0.000063 loss : 0.483481 +[10:39:21.367] iteration 80000 [243.54 sec]: learning rate : 0.000016 loss : 0.378121 +[10:39:21.525] save model to model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/iter_80000.pth +[10:39:51.975] Epoch 138 Evaluation: +[10:41:19.602] average MSE: 0.14564063425701626 average PSNR: 21.405725019546843 average SSIM: 0.5702128812987797 +[10:41:36.916] iteration 80100 [17.29 sec]: learning rate : 0.000031 loss : 0.447788 +[10:42:24.595] iteration 80200 [64.97 sec]: learning rate : 0.000031 loss : 0.382978 +[10:43:12.074] iteration 80300 [112.45 sec]: learning rate : 0.000031 loss : 0.369168 +[10:43:59.647] iteration 80400 [160.02 sec]: learning rate : 0.000031 loss : 0.380868 +[10:44:47.227] iteration 80500 [207.60 sec]: learning rate : 0.000031 loss : 0.383801 +[10:45:34.714] iteration 80600 [255.09 sec]: learning rate : 0.000031 loss : 0.312683 +[10:45:53.723] Epoch 139 Evaluation: +[10:47:23.221] average MSE: 0.1248540115591993 average PSNR: 22.067110486181193 average SSIM: 0.5846028911009264 +[10:47:52.039] iteration 80700 [28.79 sec]: learning rate : 0.000031 loss : 0.491659 +[10:48:39.725] iteration 80800 [76.48 sec]: learning rate : 0.000031 loss : 0.487718 +[10:49:27.691] iteration 80900 [124.45 sec]: learning rate : 0.000031 loss : 0.393464 +[10:50:15.367] iteration 81000 [172.12 sec]: learning rate : 0.000031 loss : 0.507073 +[10:51:02.966] iteration 81100 [219.72 sec]: learning rate : 0.000031 loss : 0.434414 +[10:51:50.566] iteration 81200 [267.32 sec]: learning rate : 0.000031 loss : 0.321487 +[10:51:58.163] Epoch 140 Evaluation: +[10:53:27.547] average MSE: 0.14101567836975812 average PSNR: 21.546590774210458 average SSIM: 0.5665858836071436 +[10:54:07.795] iteration 81300 [40.23 sec]: learning rate : 0.000031 loss : 0.321249 +[10:54:55.309] iteration 81400 [87.74 sec]: learning rate : 0.000031 loss : 0.425609 +[10:55:42.802] iteration 81500 [135.23 sec]: learning rate : 0.000031 loss : 0.491157 +[10:56:30.794] iteration 81600 [183.23 sec]: learning rate : 0.000031 loss : 0.428694 +[10:57:18.258] iteration 81700 [230.69 sec]: learning rate : 0.000031 loss : 0.445679 +[10:58:03.037] Epoch 141 Evaluation: +[10:59:30.055] average MSE: 0.1496259012131731 average PSNR: 21.283726679625843 average SSIM: 0.566314548705725 +[10:59:34.081] iteration 81800 [4.00 sec]: learning rate : 0.000031 loss : 0.567985 +[11:00:21.662] iteration 81900 [51.58 sec]: learning rate : 0.000031 loss : 0.375851 +[11:01:09.050] iteration 82000 [98.97 sec]: learning rate : 0.000031 loss : 0.374564 +[11:01:56.515] iteration 82100 [146.44 sec]: learning 
rate : 0.000031 loss : 0.438534 +[11:02:43.921] iteration 82200 [193.84 sec]: learning rate : 0.000031 loss : 0.353274 +[11:03:32.385] iteration 82300 [242.31 sec]: learning rate : 0.000031 loss : 0.406703 +[11:04:04.685] Epoch 142 Evaluation: +[11:05:32.689] average MSE: 0.13753312968231998 average PSNR: 21.65356914720465 average SSIM: 0.5727478438787054 +[11:05:48.213] iteration 82400 [15.50 sec]: learning rate : 0.000031 loss : 0.445052 +[11:06:35.615] iteration 82500 [62.90 sec]: learning rate : 0.000031 loss : 0.386114 +[11:07:23.517] iteration 82600 [110.81 sec]: learning rate : 0.000031 loss : 0.446566 +[11:08:11.189] iteration 82700 [158.48 sec]: learning rate : 0.000031 loss : 0.519053 +[11:08:59.061] iteration 82800 [206.35 sec]: learning rate : 0.000031 loss : 0.459710 +[11:09:46.562] iteration 82900 [253.85 sec]: learning rate : 0.000031 loss : 0.334705 +[11:10:07.439] Epoch 143 Evaluation: +[11:11:34.531] average MSE: 0.14884333970364383 average PSNR: 21.312753535048476 average SSIM: 0.5707536548784765 +[11:12:01.437] iteration 83000 [26.88 sec]: learning rate : 0.000031 loss : 0.369635 +[11:12:49.091] iteration 83100 [74.54 sec]: learning rate : 0.000031 loss : 0.484270 +[11:13:36.790] iteration 83200 [122.24 sec]: learning rate : 0.000031 loss : 0.392544 +[11:14:24.379] iteration 83300 [169.85 sec]: learning rate : 0.000031 loss : 0.542263 +[11:15:12.418] iteration 83400 [217.86 sec]: learning rate : 0.000031 loss : 0.375546 +[11:16:00.368] iteration 83500 [265.81 sec]: learning rate : 0.000031 loss : 0.512947 +[11:16:09.864] Epoch 144 Evaluation: +[11:17:35.986] average MSE: 0.1360459859747244 average PSNR: 21.695491491717572 average SSIM: 0.5751990907486957 +[11:18:14.142] iteration 83600 [38.13 sec]: learning rate : 0.000031 loss : 0.465524 +[11:19:02.330] iteration 83700 [86.32 sec]: learning rate : 0.000031 loss : 0.457250 +[11:19:49.944] iteration 83800 [133.93 sec]: learning rate : 0.000031 loss : 0.375646 +[11:20:37.543] iteration 83900 [181.53 sec]: learning rate : 0.000031 loss : 0.489109 +[11:21:25.073] iteration 84000 [229.06 sec]: learning rate : 0.000031 loss : 0.460915 +[11:22:10.601] Epoch 145 Evaluation: +[11:23:39.003] average MSE: 0.13416934777642772 average PSNR: 21.76267950710396 average SSIM: 0.5792442413179516 +[11:23:41.133] iteration 84100 [2.10 sec]: learning rate : 0.000031 loss : 0.403897 +[11:24:28.750] iteration 84200 [49.72 sec]: learning rate : 0.000031 loss : 0.383289 +[11:25:16.627] iteration 84300 [97.60 sec]: learning rate : 0.000031 loss : 0.401055 +[11:26:04.866] iteration 84400 [145.84 sec]: learning rate : 0.000031 loss : 0.462616 +[11:26:52.533] iteration 84500 [193.50 sec]: learning rate : 0.000031 loss : 0.369037 +[11:27:40.241] iteration 84600 [241.21 sec]: learning rate : 0.000031 loss : 0.346444 +[11:28:14.487] Epoch 146 Evaluation: +[11:29:40.970] average MSE: 0.13357091850726832 average PSNR: 21.78218127922643 average SSIM: 0.5780060282936852 +[11:29:54.457] iteration 84700 [13.46 sec]: learning rate : 0.000031 loss : 0.353590 +[11:30:42.091] iteration 84800 [61.10 sec]: learning rate : 0.000031 loss : 0.413361 +[11:31:30.064] iteration 84900 [109.07 sec]: learning rate : 0.000031 loss : 0.416433 +[11:32:17.625] iteration 85000 [156.63 sec]: learning rate : 0.000031 loss : 0.529643 +[11:33:05.287] iteration 85100 [204.29 sec]: learning rate : 0.000031 loss : 0.395755 +[11:33:53.727] iteration 85200 [252.73 sec]: learning rate : 0.000031 loss : 0.484762 +[11:34:16.567] Epoch 147 Evaluation: +[11:35:42.629] average MSE: 
0.11836813317830569 average PSNR: 22.302628996054324 average SSIM: 0.5937110849235001 +[11:36:07.475] iteration 85300 [24.82 sec]: learning rate : 0.000031 loss : 0.384616 +[11:36:55.531] iteration 85400 [72.88 sec]: learning rate : 0.000031 loss : 0.405507 +[11:37:42.913] iteration 85500 [120.26 sec]: learning rate : 0.000031 loss : 0.358863 +[11:38:30.402] iteration 85600 [167.75 sec]: learning rate : 0.000031 loss : 0.394005 +[11:39:17.931] iteration 85700 [215.28 sec]: learning rate : 0.000031 loss : 0.382657 +[11:40:05.439] iteration 85800 [262.79 sec]: learning rate : 0.000031 loss : 0.415647 +[11:40:16.833] Epoch 148 Evaluation: +[11:41:43.555] average MSE: 0.1260082720867376 average PSNR: 22.037583667581973 average SSIM: 0.5867718485344342 +[11:42:20.582] iteration 85900 [37.02 sec]: learning rate : 0.000031 loss : 0.357781 +[11:43:08.110] iteration 86000 [84.53 sec]: learning rate : 0.000031 loss : 0.364491 +[11:43:55.507] iteration 86100 [131.93 sec]: learning rate : 0.000031 loss : 0.350592 +[11:44:43.020] iteration 86200 [179.44 sec]: learning rate : 0.000031 loss : 0.423472 +[11:45:30.438] iteration 86300 [226.86 sec]: learning rate : 0.000031 loss : 0.437623 +[11:46:17.938] iteration 86400 [274.36 sec]: learning rate : 0.000031 loss : 0.326082 +[11:46:17.976] Epoch 149 Evaluation: +[11:47:44.157] average MSE: 0.15451576011785795 average PSNR: 21.160884432343334 average SSIM: 0.5692850310021247 +[11:48:32.944] iteration 86500 [48.76 sec]: learning rate : 0.000031 loss : 0.368197 +[11:49:20.519] iteration 86600 [96.34 sec]: learning rate : 0.000031 loss : 0.425389 +[11:50:08.241] iteration 86700 [144.06 sec]: learning rate : 0.000031 loss : 0.394460 +[11:50:55.938] iteration 86800 [191.76 sec]: learning rate : 0.000031 loss : 0.376207 +[11:51:43.870] iteration 86900 [239.69 sec]: learning rate : 0.000031 loss : 0.376427 +[11:52:19.991] Epoch 150 Evaluation: +[11:53:46.864] average MSE: 0.13520812306476754 average PSNR: 21.733359495926884 average SSIM: 0.5800078428174403 +[11:53:58.488] iteration 87000 [11.60 sec]: learning rate : 0.000031 loss : 0.384288 +[11:54:46.030] iteration 87100 [59.14 sec]: learning rate : 0.000031 loss : 0.408698 +[11:55:33.380] iteration 87200 [106.49 sec]: learning rate : 0.000031 loss : 0.415666 +[11:56:21.288] iteration 87300 [154.40 sec]: learning rate : 0.000031 loss : 4.685630 +[11:57:08.667] iteration 87400 [201.78 sec]: learning rate : 0.000031 loss : 0.381736 +[11:57:56.318] iteration 87500 [249.43 sec]: learning rate : 0.000031 loss : 0.453628 +[11:58:21.068] Epoch 151 Evaluation: +[11:59:50.431] average MSE: 0.1328622270014226 average PSNR: 21.813560889609928 average SSIM: 0.5797018025824321 +[12:00:13.972] iteration 87600 [23.52 sec]: learning rate : 0.000031 loss : 0.405025 +[12:01:01.602] iteration 87700 [71.15 sec]: learning rate : 0.000031 loss : 0.392582 +[12:01:49.360] iteration 87800 [118.91 sec]: learning rate : 0.000031 loss : 0.424055 +[12:02:37.122] iteration 87900 [166.67 sec]: learning rate : 0.000031 loss : 0.383232 +[12:03:24.689] iteration 88000 [214.26 sec]: learning rate : 0.000031 loss : 0.409515 +[12:04:12.633] iteration 88100 [262.18 sec]: learning rate : 0.000031 loss : 0.267602 +[12:04:25.953] Epoch 152 Evaluation: +[12:05:52.468] average MSE: 0.1237602075708357 average PSNR: 22.11620081907275 average SSIM: 0.5894690678240478 +[12:06:26.969] iteration 88200 [34.48 sec]: learning rate : 0.000031 loss : 0.350534 +[12:07:14.412] iteration 88300 [81.92 sec]: learning rate : 0.000031 loss : 0.404240 +[12:08:01.944] 
iteration 88400 [129.45 sec]: learning rate : 0.000031 loss : 0.391863 +[12:08:49.381] iteration 88500 [176.89 sec]: learning rate : 0.000031 loss : 0.363382 +[12:09:37.463] iteration 88600 [224.97 sec]: learning rate : 0.000031 loss : 0.443866 +[12:10:25.537] iteration 88700 [273.05 sec]: learning rate : 0.000031 loss : 0.286200 +[12:10:27.457] Epoch 153 Evaluation: +[12:11:56.590] average MSE: 0.1050290098014774 average PSNR: 22.82539873758182 average SSIM: 0.6056675780591855 +[12:12:42.295] iteration 88800 [45.68 sec]: learning rate : 0.000031 loss : 0.523326 +[12:13:29.809] iteration 88900 [93.20 sec]: learning rate : 0.000031 loss : 0.381612 +[12:14:17.256] iteration 89000 [140.64 sec]: learning rate : 0.000031 loss : 0.475257 +[12:15:04.674] iteration 89100 [188.07 sec]: learning rate : 0.000031 loss : 0.415026 +[12:15:52.692] iteration 89200 [236.08 sec]: learning rate : 0.000031 loss : 0.412269 +[12:16:30.668] Epoch 154 Evaluation: +[12:17:58.224] average MSE: 0.13101072300860672 average PSNR: 21.87489446595801 average SSIM: 0.5859965026818958 +[12:18:08.051] iteration 89300 [9.80 sec]: learning rate : 0.000031 loss : 0.388317 +[12:18:56.517] iteration 89400 [58.27 sec]: learning rate : 0.000031 loss : 0.361679 +[12:19:44.047] iteration 89500 [105.80 sec]: learning rate : 0.000031 loss : 0.369902 +[12:20:31.453] iteration 89600 [153.21 sec]: learning rate : 0.000031 loss : 0.479322 +[12:21:19.040] iteration 89700 [200.79 sec]: learning rate : 0.000031 loss : 0.528813 +[12:22:06.988] iteration 89800 [248.74 sec]: learning rate : 0.000031 loss : 0.555313 +[12:22:33.543] Epoch 155 Evaluation: +[12:24:00.153] average MSE: 0.14267772984492458 average PSNR: 21.50055033097998 average SSIM: 0.5769744193457147 +[12:24:21.211] iteration 89900 [21.03 sec]: learning rate : 0.000031 loss : 0.595870 +[12:25:08.727] iteration 90000 [68.55 sec]: learning rate : 0.000031 loss : 0.316315 +[12:25:56.137] iteration 90100 [115.96 sec]: learning rate : 0.000031 loss : 0.408631 +[12:26:44.087] iteration 90200 [163.91 sec]: learning rate : 0.000031 loss : 0.444332 +[12:27:32.395] iteration 90300 [212.22 sec]: learning rate : 0.000031 loss : 0.427621 +[12:28:19.857] iteration 90400 [259.68 sec]: learning rate : 0.000031 loss : 0.469218 +[12:28:35.032] Epoch 156 Evaluation: +[12:30:01.353] average MSE: 0.14973900191146022 average PSNR: 21.29091030362212 average SSIM: 0.570390011968083 +[12:30:33.780] iteration 90500 [32.40 sec]: learning rate : 0.000031 loss : 0.372674 +[12:31:21.376] iteration 90600 [80.00 sec]: learning rate : 0.000031 loss : 0.282634 +[12:32:08.836] iteration 90700 [127.46 sec]: learning rate : 0.000031 loss : 0.389942 +[12:32:56.786] iteration 90800 [175.41 sec]: learning rate : 0.000031 loss : 0.359076 +[12:33:44.934] iteration 90900 [223.56 sec]: learning rate : 0.000031 loss : 0.438976 +[12:34:32.350] iteration 91000 [270.97 sec]: learning rate : 0.000031 loss : 0.398755 +[12:34:36.149] Epoch 157 Evaluation: +[12:36:06.517] average MSE: 0.1273452535339835 average PSNR: 21.995252750901603 average SSIM: 0.5857547332175992 +[12:36:50.997] iteration 91100 [44.46 sec]: learning rate : 0.000031 loss : 0.355468 +[12:37:38.662] iteration 91200 [92.12 sec]: learning rate : 0.000031 loss : 0.347249 +[12:38:26.677] iteration 91300 [140.13 sec]: learning rate : 0.000031 loss : 0.336333 +[12:39:14.165] iteration 91400 [187.62 sec]: learning rate : 0.000031 loss : 0.361062 +[12:40:01.662] iteration 91500 [235.12 sec]: learning rate : 0.000031 loss : 0.299869 +[12:40:41.499] Epoch 158 Evaluation: 
+[12:42:11.750] average MSE: 0.12205958727512321 average PSNR: 22.18222584029749 average SSIM: 0.5910557601575326 +[12:42:19.545] iteration 91600 [7.77 sec]: learning rate : 0.000031 loss : 0.425525 +[12:43:07.111] iteration 91700 [55.34 sec]: learning rate : 0.000031 loss : 0.372683 +[12:43:55.042] iteration 91800 [103.27 sec]: learning rate : 0.000031 loss : 0.524806 +[12:44:42.660] iteration 91900 [150.89 sec]: learning rate : 0.000031 loss : 0.371647 +[12:45:30.743] iteration 92000 [198.97 sec]: learning rate : 0.000031 loss : 0.271521 +[12:46:18.156] iteration 92100 [246.38 sec]: learning rate : 0.000031 loss : 0.479779 +[12:46:46.728] Epoch 159 Evaluation: +[12:48:15.827] average MSE: 0.11992272812459473 average PSNR: 22.26116025552813 average SSIM: 0.5920239941827875 +[12:48:35.226] iteration 92200 [19.38 sec]: learning rate : 0.000031 loss : 0.308753 +[12:49:22.784] iteration 92300 [66.93 sec]: learning rate : 0.000031 loss : 0.455237 +[12:50:10.137] iteration 92400 [114.29 sec]: learning rate : 0.000031 loss : 0.395016 +[12:50:57.573] iteration 92500 [161.72 sec]: learning rate : 0.000031 loss : 0.424823 +[12:51:45.067] iteration 92600 [209.22 sec]: learning rate : 0.000031 loss : 0.397613 +[12:52:32.429] iteration 92700 [256.58 sec]: learning rate : 0.000031 loss : 0.383094 +[12:52:49.507] Epoch 160 Evaluation: +[12:54:16.830] average MSE: 0.11154836494724282 average PSNR: 22.573985189700426 average SSIM: 0.5997634169958252 +[12:54:47.505] iteration 92800 [30.65 sec]: learning rate : 0.000031 loss : 0.426184 +[12:55:34.863] iteration 92900 [78.01 sec]: learning rate : 0.000031 loss : 0.466720 +[12:56:22.560] iteration 93000 [125.71 sec]: learning rate : 0.000031 loss : 0.415257 +[12:57:10.111] iteration 93100 [173.26 sec]: learning rate : 0.000031 loss : 0.493633 +[12:57:57.600] iteration 93200 [220.75 sec]: learning rate : 0.000031 loss : 0.392987 +[12:58:45.092] iteration 93300 [268.24 sec]: learning rate : 0.000031 loss : 0.424191 +[12:58:50.802] Epoch 161 Evaluation: +[13:00:16.967] average MSE: 0.10997041735816912 average PSNR: 22.6300837694302 average SSIM: 0.6019258189858394 +[13:00:59.079] iteration 93400 [42.09 sec]: learning rate : 0.000031 loss : 0.443839 +[13:01:46.490] iteration 93500 [89.50 sec]: learning rate : 0.000031 loss : 0.284310 +[13:02:33.996] iteration 93600 [137.01 sec]: learning rate : 0.000031 loss : 0.433352 +[13:03:22.254] iteration 93700 [185.26 sec]: learning rate : 0.000031 loss : 0.314657 +[13:04:09.687] iteration 93800 [232.70 sec]: learning rate : 0.000031 loss : 0.389547 +[13:04:51.498] Epoch 162 Evaluation: +[13:06:17.218] average MSE: 0.1272321754764066 average PSNR: 21.995015501131252 average SSIM: 0.5844231560511031 +[13:06:23.117] iteration 93900 [5.87 sec]: learning rate : 0.000031 loss : 0.364317 +[13:07:10.478] iteration 94000 [53.24 sec]: learning rate : 0.000031 loss : 0.362846 +[13:07:58.010] iteration 94100 [100.77 sec]: learning rate : 0.000031 loss : 0.350363 +[13:08:45.548] iteration 94200 [148.31 sec]: learning rate : 0.000031 loss : 0.354571 +[13:09:33.197] iteration 94300 [195.96 sec]: learning rate : 0.000031 loss : 0.446870 +[13:10:20.974] iteration 94400 [243.73 sec]: learning rate : 0.000031 loss : 0.481888 +[13:10:51.800] Epoch 163 Evaluation: +[13:12:22.930] average MSE: 0.13031757174122882 average PSNR: 21.897845730392465 average SSIM: 0.5835307898764993 +[13:12:40.455] iteration 94500 [17.50 sec]: learning rate : 0.000031 loss : 0.554435 +[13:13:28.017] iteration 94600 [65.06 sec]: learning rate : 0.000031 loss : 
0.394859 +[13:14:15.639] iteration 94700 [112.69 sec]: learning rate : 0.000031 loss : 0.319171 +[13:15:03.247] iteration 94800 [160.30 sec]: learning rate : 0.000031 loss : 0.359344 +[13:15:50.902] iteration 94900 [207.95 sec]: learning rate : 0.000031 loss : 0.441040 +[13:16:38.532] iteration 95000 [255.58 sec]: learning rate : 0.000031 loss : 0.326441 +[13:16:57.542] Epoch 164 Evaluation: +[13:18:27.856] average MSE: 0.13663780660090108 average PSNR: 21.695700675647828 average SSIM: 0.5802091292914402 +[13:18:56.496] iteration 95100 [28.62 sec]: learning rate : 0.000031 loss : 0.382184 +[13:19:44.057] iteration 95200 [76.18 sec]: learning rate : 0.000031 loss : 0.405571 +[13:20:31.540] iteration 95300 [123.66 sec]: learning rate : 0.000031 loss : 0.419568 +[13:21:19.351] iteration 95400 [171.47 sec]: learning rate : 0.000031 loss : 0.725294 +[13:22:06.909] iteration 95500 [219.03 sec]: learning rate : 0.000031 loss : 0.428053 +[13:22:54.401] iteration 95600 [266.52 sec]: learning rate : 0.000031 loss : 0.339612 +[13:23:02.012] Epoch 165 Evaluation: +[13:24:30.228] average MSE: 0.10395904848427513 average PSNR: 22.864174807061406 average SSIM: 0.6097137933421534 +[13:25:10.211] iteration 95700 [39.96 sec]: learning rate : 0.000031 loss : 0.292470 +[13:25:58.238] iteration 95800 [87.99 sec]: learning rate : 0.000031 loss : 0.442926 +[13:26:45.772] iteration 95900 [135.52 sec]: learning rate : 0.000031 loss : 0.533688 +[13:27:33.146] iteration 96000 [182.89 sec]: learning rate : 0.000031 loss : 0.398060 +[13:28:20.719] iteration 96100 [230.47 sec]: learning rate : 0.000031 loss : 0.393613 +[13:29:04.392] Epoch 166 Evaluation: +[13:30:32.510] average MSE: 0.11510939633477703 average PSNR: 22.429912496205226 average SSIM: 0.5949622779501966 +[13:30:36.527] iteration 96200 [3.99 sec]: learning rate : 0.000031 loss : 0.549023 +[13:31:24.011] iteration 96300 [51.48 sec]: learning rate : 0.000031 loss : 0.385094 +[13:32:11.431] iteration 96400 [98.90 sec]: learning rate : 0.000031 loss : 0.379982 +[13:32:58.868] iteration 96500 [146.33 sec]: learning rate : 0.000031 loss : 0.416507 +[13:33:46.741] iteration 96600 [194.21 sec]: learning rate : 0.000031 loss : 0.355751 +[13:34:34.248] iteration 96700 [241.71 sec]: learning rate : 0.000031 loss : 0.460604 +[13:35:06.503] Epoch 167 Evaluation: +[13:36:35.047] average MSE: 0.10886781013836404 average PSNR: 22.66519336801812 average SSIM: 0.6038997676767537 +[13:36:50.473] iteration 96800 [15.40 sec]: learning rate : 0.000031 loss : 0.376432 +[13:37:38.191] iteration 96900 [63.12 sec]: learning rate : 0.000031 loss : 0.404906 +[13:38:25.758] iteration 97000 [110.69 sec]: learning rate : 0.000031 loss : 0.456628 +[13:39:13.795] iteration 97100 [158.72 sec]: learning rate : 0.000031 loss : 0.469081 +[13:40:01.242] iteration 97200 [206.17 sec]: learning rate : 0.000031 loss : 0.459882 +[13:40:48.816] iteration 97300 [253.74 sec]: learning rate : 0.000031 loss : 0.296579 +[13:41:09.744] Epoch 168 Evaluation: +[13:42:35.539] average MSE: 0.11766270224651902 average PSNR: 22.340197630621635 average SSIM: 0.597162103298995 +[13:43:02.255] iteration 97400 [26.69 sec]: learning rate : 0.000031 loss : 0.461781 +[13:43:49.783] iteration 97500 [74.22 sec]: learning rate : 0.000031 loss : 0.379591 +[13:44:37.100] iteration 97600 [121.54 sec]: learning rate : 0.000031 loss : 0.308600 +[13:45:24.521] iteration 97700 [168.96 sec]: learning rate : 0.000031 loss : 0.398670 +[13:46:12.037] iteration 97800 [216.47 sec]: learning rate : 0.000031 loss : 0.447936 
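[editor's note] Logs in this format run to tens of thousands of entries, so it is often easier to extract the iteration losses and per-epoch MSE/PSNR/SSIM programmatically than to read them inline. The standalone helper below is written against the exact line format visible above ("iteration N [T sec]: learning rate : LR loss : L" and "average MSE: ... average PSNR: ... average SSIM: ..."); it is not part of the repository, and the log path in the usage stub is a placeholder.

```python
# Standalone helper (not part of the repository) for summarising a
# training log in the format shown above.
import re

ITER_RE = re.compile(
    r"iteration (\d+) \[[\d.]+ sec\]: learning rate : ([\d.]+) loss : ([\d.]+)")
EVAL_RE = re.compile(
    r"average MSE: ([\d.e-]+) average PSNR: ([\d.e-]+) average SSIM: ([\d.e-]+)")

def parse_log(path):
    """Return (iteration records, evaluation records) parsed from a log file."""
    iters, evals = [], []
    with open(path) as fh:
        for line in fh:
            m = ITER_RE.search(line)
            if m:
                iters.append((int(m.group(1)), float(m.group(2)), float(m.group(3))))
                continue
            m = EVAL_RE.search(line)
            if m:
                evals.append(tuple(float(g) for g in m.groups()))
    return iters, evals

if __name__ == "__main__":
    # Placeholder path; point it at the actual log.txt from the run.
    iters, evals = parse_log("log.txt")
    if evals:
        best = max(range(len(evals)), key=lambda i: evals[i][1])
        print(f"parsed {len(iters)} iteration lines, {len(evals)} evaluations")
        print(f"best evaluation by PSNR: index {best} "
              f"(MSE={evals[best][0]:.4f}, PSNR={evals[best][1]:.2f}, "
              f"SSIM={evals[best][2]:.4f})")
```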
+[13:46:59.381] iteration 97900 [263.82 sec]: learning rate : 0.000031 loss : 0.445464 +[13:47:08.857] Epoch 169 Evaluation: +[13:48:35.183] average MSE: 0.11040024933372834 average PSNR: 22.612819153243468 average SSIM: 0.6031891569152102 +[13:49:13.625] iteration 98000 [38.42 sec]: learning rate : 0.000031 loss : 0.352502 +[13:50:01.130] iteration 98100 [85.92 sec]: learning rate : 0.000031 loss : 0.403979 +[13:50:48.658] iteration 98200 [133.45 sec]: learning rate : 0.000031 loss : 0.363485 +[13:51:36.102] iteration 98300 [180.89 sec]: learning rate : 0.000031 loss : 0.385335 +[13:52:23.500] iteration 98400 [228.29 sec]: learning rate : 0.000031 loss : 0.459270 +[13:53:09.091] Epoch 170 Evaluation: +[13:54:35.714] average MSE: 0.14075585797814558 average PSNR: 21.561335402876466 average SSIM: 0.5738438245807218 +[13:54:37.819] iteration 98500 [2.08 sec]: learning rate : 0.000031 loss : 0.407908 +[13:55:25.291] iteration 98600 [49.55 sec]: learning rate : 0.000031 loss : 0.528537 +[13:56:13.343] iteration 98700 [97.61 sec]: learning rate : 0.000031 loss : 0.482105 +[13:57:01.188] iteration 98800 [145.45 sec]: learning rate : 0.000031 loss : 0.463238 +[13:57:48.648] iteration 98900 [192.91 sec]: learning rate : 0.000031 loss : 0.336735 +[13:58:36.023] iteration 99000 [240.29 sec]: learning rate : 0.000031 loss : 0.439658 +[13:59:10.229] Epoch 171 Evaluation: +[14:00:35.888] average MSE: 0.12470778827000857 average PSNR: 22.091734939029585 average SSIM: 0.5929546376367357 +[14:00:49.396] iteration 99100 [13.49 sec]: learning rate : 0.000031 loss : 0.341618 +[14:01:37.081] iteration 99200 [61.17 sec]: learning rate : 0.000031 loss : 0.403433 +[14:02:24.573] iteration 99300 [108.66 sec]: learning rate : 0.000031 loss : 0.393190 +[14:03:12.768] iteration 99400 [156.86 sec]: learning rate : 0.000031 loss : 0.487900 +[14:04:00.286] iteration 99500 [204.38 sec]: learning rate : 0.000031 loss : 0.387135 +[14:04:47.908] iteration 99600 [252.00 sec]: learning rate : 0.000031 loss : 0.424540 +[14:05:10.733] Epoch 172 Evaluation: +[14:06:40.613] average MSE: 0.1155188353451852 average PSNR: 22.423484082855335 average SSIM: 0.6033781205947474 +[14:07:05.667] iteration 99700 [25.03 sec]: learning rate : 0.000031 loss : 0.337835 +[14:07:53.270] iteration 99800 [72.64 sec]: learning rate : 0.000031 loss : 0.377346 +[14:08:40.870] iteration 99900 [120.24 sec]: learning rate : 0.000031 loss : 0.356040 +[14:09:28.308] iteration 100000 [167.67 sec]: learning rate : 0.000008 loss : 0.386255 +[14:09:28.465] save model to model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/iter_100000.pth +[14:09:28.954] Epoch 173 Evaluation: +[14:10:55.674] average MSE: 0.11338342732751923 average PSNR: 22.50435805726616 average SSIM: 0.6027459139405325 +[14:10:55.979] save model to model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/iter_100000.pth +===> Evaluate Metric <=== +Results +------------------------------------ +ColdDiffusion NMSE: 0.8723 ± 0.0475 +ColdDiffusion PSNR: 34.5184 ± 0.4337 +ColdDiffusion SSIM: 0.8950 ± 0.0078 +------------------------------------ +All NMSE: 0.8706 ± 0.1335 +All PSNR: 33.4786 ± 0.7384 +All SSIM: 0.8786 ± 0.0129 +------------------------------------ +Save Path: /home/v-qichen3/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/result_case/ \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/log/events.out.tfevents.1752550781.GCRSANDBOX133 
b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/log/events.out.tfevents.1752550781.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..abf3d963cc06369bfe73be4b9cc99824e3426bcc --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t10_new_kspace_time/log/events.out.tfevents.1752550781.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15c12a228a954b3d78f770aa20f6d56fbfea279585a5631b6aae3b8abaa3f2a6 +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..80b0fe16a1fdab5ba4d97ed2a3d602180742127d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ee119fc27d36ad9cc4a8b95005d673c8744c8ac87530e7a40e4226f371d3cc6 +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..7fe9657e4ede487c7333a82d16e5ebf083d9fd4f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/log.txt @@ -0,0 +1,1368 @@ +[20:34:42.994] Namespace(root_path='/home/v-qichen3/MRI_recon/data/m4raw', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='2', exp='FSMNet_m4raw_4x_lr5e-4', max_iterations=100000, batch_size=4, base_lr=0.0005, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=20, image_size=240, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[20:35:05.778] Namespace(root_path='/home/v-qichen3/MRI_recon/data/m4raw', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='2', exp='FSMNet_m4raw_4x_lr5e-4', max_iterations=100000, batch_size=4, base_lr=0.0005, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=20, image_size=240, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[20:41:09.315] iteration 100 [48.68 sec]: learning rate : 0.000500 loss : 0.548355 +[20:41:56.681] iteration 200 [96.04 sec]: learning rate : 0.000500 loss : 0.489771 +[20:42:44.209] iteration 300 [143.57 sec]: learning rate : 0.000500 loss : 0.430007 +[20:43:32.279] iteration 400 [191.64 
sec]: learning rate : 0.000500 loss : 0.619990 +[20:44:19.763] iteration 500 [239.13 sec]: learning rate : 0.000500 loss : 0.540711 +[20:44:55.901] Epoch 0 Evaluation: +[20:47:48.489] average MSE: 0.05242393655223164 average PSNR: 25.84477281972974 average SSIM: 0.6945999729710618 +[20:48:00.120] iteration 600 [11.61 sec]: learning rate : 0.000500 loss : 0.436492 +[20:48:48.333] iteration 700 [59.83 sec]: learning rate : 0.000500 loss : 0.363718 +[20:49:35.886] iteration 800 [107.37 sec]: learning rate : 0.000500 loss : 0.419218 +[20:50:23.802] iteration 900 [155.29 sec]: learning rate : 0.000500 loss : 0.368809 +[20:51:11.478] iteration 1000 [202.97 sec]: learning rate : 0.000500 loss : 0.325579 +[20:51:59.652] iteration 1100 [251.14 sec]: learning rate : 0.000500 loss : 0.456892 +[20:52:24.463] Epoch 1 Evaluation: +[20:55:16.742] average MSE: 0.048761416617853 average PSNR: 26.151313863412845 average SSIM: 0.6962001752206908 +[20:55:39.727] iteration 1200 [22.96 sec]: learning rate : 0.000500 loss : 0.396915 +[20:56:27.390] iteration 1300 [70.63 sec]: learning rate : 0.000500 loss : 0.362194 +[20:57:14.861] iteration 1400 [118.10 sec]: learning rate : 0.000500 loss : 0.405825 +[20:58:02.986] iteration 1500 [166.22 sec]: learning rate : 0.000500 loss : 0.383045 +[20:58:50.523] iteration 1600 [213.76 sec]: learning rate : 0.000500 loss : 0.447154 +[20:59:38.963] iteration 1700 [262.20 sec]: learning rate : 0.000500 loss : 0.392497 +[20:59:52.338] Epoch 2 Evaluation: +[21:02:50.446] average MSE: 0.04995786560906825 average PSNR: 26.04423642144626 average SSIM: 0.6894570680229976 +[21:03:25.101] iteration 1800 [34.64 sec]: learning rate : 0.000500 loss : 0.389355 +[21:04:13.330] iteration 1900 [82.87 sec]: learning rate : 0.000500 loss : 0.407575 +[21:05:01.138] iteration 2000 [130.67 sec]: learning rate : 0.000500 loss : 0.419787 +[21:05:49.417] iteration 2100 [178.95 sec]: learning rate : 0.000500 loss : 0.357548 +[21:06:37.135] iteration 2200 [226.67 sec]: learning rate : 0.000500 loss : 0.282558 +[21:07:24.872] iteration 2300 [274.40 sec]: learning rate : 0.000500 loss : 0.320719 +[21:07:26.788] Epoch 3 Evaluation: +[21:10:22.773] average MSE: 0.05077012963722015 average PSNR: 25.973727143375427 average SSIM: 0.6835036182172038 +[21:11:09.320] iteration 2400 [46.52 sec]: learning rate : 0.000500 loss : 0.401752 +[21:11:56.806] iteration 2500 [94.01 sec]: learning rate : 0.000500 loss : 0.332018 +[21:12:44.381] iteration 2600 [141.59 sec]: learning rate : 0.000500 loss : 0.405515 +[21:13:32.048] iteration 2700 [189.25 sec]: learning rate : 0.000500 loss : 0.417525 +[21:14:19.682] iteration 2800 [236.89 sec]: learning rate : 0.000500 loss : 0.352484 +[21:14:57.849] Epoch 4 Evaluation: +[21:17:52.865] average MSE: 0.05224657240798211 average PSNR: 25.849406654442948 average SSIM: 0.6772241056899303 +[21:18:02.575] iteration 2900 [9.69 sec]: learning rate : 0.000500 loss : 0.273317 +[21:18:50.206] iteration 3000 [57.32 sec]: learning rate : 0.000500 loss : 0.373011 +[21:19:37.955] iteration 3100 [105.07 sec]: learning rate : 0.000500 loss : 0.372658 +[21:20:25.668] iteration 3200 [152.78 sec]: learning rate : 0.000500 loss : 0.393182 +[21:21:13.368] iteration 3300 [200.48 sec]: learning rate : 0.000500 loss : 0.328634 +[21:22:01.619] iteration 3400 [248.73 sec]: learning rate : 0.000500 loss : 0.396799 +[21:22:28.616] Epoch 5 Evaluation: +[21:25:26.547] average MSE: 0.05316328318796439 average PSNR: 25.772847095310972 average SSIM: 0.674448397367536 +[21:25:47.783] iteration 3500 [21.21 
sec]: learning rate : 0.000500 loss : 0.411972 +[21:26:35.579] iteration 3600 [69.01 sec]: learning rate : 0.000500 loss : 0.363932 +[21:27:23.151] iteration 3700 [116.58 sec]: learning rate : 0.000500 loss : 0.388520 +[21:28:11.320] iteration 3800 [164.75 sec]: learning rate : 0.000500 loss : 0.380213 +[21:28:59.218] iteration 3900 [212.65 sec]: learning rate : 0.000500 loss : 0.339040 +[21:29:46.774] iteration 4000 [260.20 sec]: learning rate : 0.000500 loss : 0.367057 +[21:30:02.004] Epoch 6 Evaluation: +[21:32:54.547] average MSE: 0.0449220221526813 average PSNR: 26.513088055869037 average SSIM: 0.7164621006719605 +[21:33:27.783] iteration 4100 [33.21 sec]: learning rate : 0.000500 loss : 0.387059 +[21:34:15.416] iteration 4200 [80.85 sec]: learning rate : 0.000500 loss : 0.364654 +[21:35:04.129] iteration 4300 [129.56 sec]: learning rate : 0.000500 loss : 0.339422 +[21:35:51.729] iteration 4400 [177.16 sec]: learning rate : 0.000500 loss : 0.396217 +[21:36:39.501] iteration 4500 [224.93 sec]: learning rate : 0.000500 loss : 0.448229 +[21:37:27.234] iteration 4600 [272.66 sec]: learning rate : 0.000500 loss : 0.404937 +[21:37:31.044] Epoch 7 Evaluation: +[21:40:23.534] average MSE: 0.05255968062652358 average PSNR: 25.82532118501119 average SSIM: 0.663707161848597 +[21:41:07.823] iteration 4700 [44.27 sec]: learning rate : 0.000500 loss : 0.292477 +[21:41:55.453] iteration 4800 [91.90 sec]: learning rate : 0.000500 loss : 0.401433 +[21:42:43.085] iteration 4900 [139.53 sec]: learning rate : 0.000500 loss : 0.359161 +[21:43:30.576] iteration 5000 [187.02 sec]: learning rate : 0.000500 loss : 0.352800 +[21:44:19.148] iteration 5100 [235.59 sec]: learning rate : 0.000500 loss : 0.288492 +[21:44:59.096] Epoch 8 Evaluation: +[21:47:52.430] average MSE: 0.04845840906277631 average PSNR: 26.177294712304885 average SSIM: 0.6991141006898869 +[21:48:00.365] iteration 5200 [7.91 sec]: learning rate : 0.000500 loss : 0.394695 +[21:48:47.805] iteration 5300 [55.35 sec]: learning rate : 0.000500 loss : 0.393788 +[21:49:35.338] iteration 5400 [102.88 sec]: learning rate : 0.000500 loss : 0.353882 +[21:50:23.184] iteration 5500 [150.73 sec]: learning rate : 0.000500 loss : 0.274528 +[21:51:10.724] iteration 5600 [198.27 sec]: learning rate : 0.000500 loss : 0.436253 +[21:51:58.281] iteration 5700 [245.83 sec]: learning rate : 0.000500 loss : 0.290077 +[21:52:27.050] Epoch 9 Evaluation: +[21:55:20.749] average MSE: 0.04426152379067602 average PSNR: 26.568468390819156 average SSIM: 0.7115007211430306 +[21:55:40.362] iteration 5800 [19.59 sec]: learning rate : 0.000500 loss : 0.409161 +[21:56:27.989] iteration 5900 [67.22 sec]: learning rate : 0.000500 loss : 0.401935 +[21:57:15.642] iteration 6000 [114.87 sec]: learning rate : 0.000500 loss : 0.363221 +[21:58:03.101] iteration 6100 [162.33 sec]: learning rate : 0.000500 loss : 0.363472 +[21:58:51.260] iteration 6200 [210.49 sec]: learning rate : 0.000500 loss : 0.352662 +[21:59:38.840] iteration 6300 [258.07 sec]: learning rate : 0.000500 loss : 0.427832 +[21:59:55.947] Epoch 10 Evaluation: +[22:02:48.523] average MSE: 0.04447947486621098 average PSNR: 26.549976114345174 average SSIM: 0.706639325461073 +[22:03:19.172] iteration 6400 [30.63 sec]: learning rate : 0.000500 loss : 0.419008 +[22:04:06.881] iteration 6500 [78.34 sec]: learning rate : 0.000500 loss : 0.364277 +[22:04:54.969] iteration 6600 [126.42 sec]: learning rate : 0.000500 loss : 0.362495 +[22:05:42.445] iteration 6700 [173.90 sec]: learning rate : 0.000500 loss : 0.443638 +[22:06:30.051] 
iteration 6800 [221.51 sec]: learning rate : 0.000500 loss : 0.476716 +[22:07:18.292] iteration 6900 [269.75 sec]: learning rate : 0.000500 loss : 0.350051 +[22:07:24.101] Epoch 11 Evaluation: +[22:10:22.782] average MSE: 0.05537294228812629 average PSNR: 25.60194314376832 average SSIM: 0.6414937138401885 +[22:11:05.825] iteration 7000 [43.02 sec]: learning rate : 0.000500 loss : 0.391268 +[22:11:53.444] iteration 7100 [90.64 sec]: learning rate : 0.000500 loss : 0.395224 +[22:12:41.474] iteration 7200 [138.67 sec]: learning rate : 0.000500 loss : 0.363412 +[22:13:29.016] iteration 7300 [186.21 sec]: learning rate : 0.000500 loss : 0.313532 +[22:14:16.593] iteration 7400 [233.79 sec]: learning rate : 0.000500 loss : 0.322512 +[22:14:58.340] Epoch 12 Evaluation: +[22:17:50.858] average MSE: 0.049640601171190735 average PSNR: 26.071823283480395 average SSIM: 0.6717997243194492 +[22:17:57.167] iteration 7500 [6.29 sec]: learning rate : 0.000500 loss : 0.428802 +[22:18:44.722] iteration 7600 [53.84 sec]: learning rate : 0.000500 loss : 0.368502 +[22:19:32.228] iteration 7700 [101.37 sec]: learning rate : 0.000500 loss : 0.273425 +[22:20:19.979] iteration 7800 [149.10 sec]: learning rate : 0.000500 loss : 0.440077 +[22:21:07.557] iteration 7900 [196.68 sec]: learning rate : 0.000500 loss : 0.291747 +[22:21:55.155] iteration 8000 [244.27 sec]: learning rate : 0.000500 loss : 0.417772 +[22:22:25.625] Epoch 13 Evaluation: +[22:25:21.262] average MSE: 0.04708113507245313 average PSNR: 26.299297029898298 average SSIM: 0.6907954612810813 +[22:25:38.514] iteration 8100 [17.23 sec]: learning rate : 0.000500 loss : 0.357362 +[22:26:26.070] iteration 8200 [64.78 sec]: learning rate : 0.000500 loss : 0.348273 +[22:27:13.703] iteration 8300 [112.42 sec]: learning rate : 0.000500 loss : 0.312312 +[22:28:01.396] iteration 8400 [160.11 sec]: learning rate : 0.000500 loss : 0.371545 +[22:28:49.217] iteration 8500 [207.93 sec]: learning rate : 0.000500 loss : 2.304659 +[22:29:37.392] iteration 8600 [256.11 sec]: learning rate : 0.000500 loss : 0.349166 +[22:29:56.458] Epoch 14 Evaluation: +[22:32:53.736] average MSE: 0.04546247580878657 average PSNR: 26.457473643810044 average SSIM: 0.694555812531186 +[22:33:22.385] iteration 8700 [28.63 sec]: learning rate : 0.000500 loss : 0.352370 +[22:34:10.140] iteration 8800 [76.38 sec]: learning rate : 0.000500 loss : 0.309552 +[22:34:58.446] iteration 8900 [124.71 sec]: learning rate : 0.000500 loss : 0.436108 +[22:35:46.697] iteration 9000 [172.94 sec]: learning rate : 0.000500 loss : 0.331876 +[22:36:34.302] iteration 9100 [220.54 sec]: learning rate : 0.000500 loss : 0.707187 +[22:37:21.775] iteration 9200 [268.02 sec]: learning rate : 0.000500 loss : 0.343161 +[22:37:29.370] Epoch 15 Evaluation: +[22:40:22.413] average MSE: 0.04262396551058519 average PSNR: 26.73444136345808 average SSIM: 0.7257174268409493 +[22:41:02.600] iteration 9300 [40.16 sec]: learning rate : 0.000500 loss : 0.386341 +[22:41:50.610] iteration 9400 [88.17 sec]: learning rate : 0.000500 loss : 0.370543 +[22:42:38.182] iteration 9500 [135.75 sec]: learning rate : 0.000500 loss : 0.330911 +[22:43:25.760] iteration 9600 [183.32 sec]: learning rate : 0.000500 loss : 0.301209 +[22:44:13.291] iteration 9700 [230.85 sec]: learning rate : 0.000500 loss : 0.363492 +[22:44:57.071] Epoch 16 Evaluation: +[22:47:50.452] average MSE: 0.04576363210445284 average PSNR: 26.426143919230288 average SSIM: 0.7005409049354703 +[22:47:54.458] iteration 9800 [3.98 sec]: learning rate : 0.000500 loss : 1.990100 
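
The "Epoch N Evaluation" lines above report an average MSE, PSNR, and SSIM over the validation set after each epoch. Below is a minimal sketch of how such per-slice metrics are commonly accumulated and averaged with scikit-image; the function name `evaluate_epoch`, the `val_loader` iterable, and the per-slice `data_range` handling are assumptions for illustration only and may not match this repository's actual evaluation code or intensity scaling.

```python
# Hypothetical sketch of a per-epoch evaluation loop that averages MSE/PSNR/SSIM,
# in the spirit of the "Epoch N Evaluation" lines in this log. Names such as
# `model`, `val_loader`, and the per-slice data_range are assumptions, not the
# repository's actual code.
import numpy as np
import torch
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


@torch.no_grad()
def evaluate_epoch(model, val_loader, device="cuda"):
    model.eval()
    mses, psnrs, ssims = [], [], []
    for under, target in val_loader:            # undersampled input, fully sampled target
        pred = model(under.to(device)).cpu().numpy()
        gt = target.numpy()
        for p, g in zip(pred, gt):              # iterate over slices in the batch
            p, g = p.squeeze(), g.squeeze()     # assume single-channel 2D slices
            mses.append(float(np.mean((p - g) ** 2)))
            psnrs.append(peak_signal_noise_ratio(g, p, data_range=float(g.max())))
            ssims.append(structural_similarity(g, p, data_range=float(g.max())))
    return np.mean(mses), np.mean(psnrs), np.mean(ssims)
```
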
+[22:48:42.254] iteration 9900 [51.78 sec]: learning rate : 0.000500 loss : 0.277933 +[22:49:29.707] iteration 10000 [99.23 sec]: learning rate : 0.000500 loss : 0.349416 +[22:50:17.337] iteration 10100 [146.86 sec]: learning rate : 0.000500 loss : 0.380274 +[22:51:04.922] iteration 10200 [194.45 sec]: learning rate : 0.000500 loss : 0.281496 +[22:51:52.875] iteration 10300 [242.40 sec]: learning rate : 0.000500 loss : 0.412768 +[22:52:25.387] Epoch 17 Evaluation: +[22:55:27.283] average MSE: 0.03895129854997659 average PSNR: 27.150153462338533 average SSIM: 0.7407313802627449 +[22:55:43.019] iteration 10400 [15.74 sec]: learning rate : 0.000500 loss : 0.334070 +[22:56:30.636] iteration 10500 [63.33 sec]: learning rate : 0.000500 loss : 0.359422 +[22:57:18.505] iteration 10600 [111.20 sec]: learning rate : 0.000500 loss : 0.394109 +[22:58:06.065] iteration 10700 [158.76 sec]: learning rate : 0.000500 loss : 0.332064 +[22:58:53.559] iteration 10800 [206.25 sec]: learning rate : 0.000500 loss : 0.351856 +[22:59:41.628] iteration 10900 [254.32 sec]: learning rate : 0.000500 loss : 0.319206 +[23:00:02.524] Epoch 18 Evaluation: +[23:02:55.785] average MSE: 0.05077366853444975 average PSNR: 25.97712469934738 average SSIM: 0.6785511248902739 +[23:03:22.716] iteration 11000 [26.90 sec]: learning rate : 0.000500 loss : 0.425905 +[23:04:10.150] iteration 11100 [74.34 sec]: learning rate : 0.000500 loss : 0.333244 +[23:04:58.066] iteration 11200 [122.26 sec]: learning rate : 0.000500 loss : 0.321599 +[23:05:46.082] iteration 11300 [170.27 sec]: learning rate : 0.000500 loss : 0.547613 +[23:06:33.596] iteration 11400 [217.78 sec]: learning rate : 0.000500 loss : 0.274102 +[23:07:21.214] iteration 11500 [265.40 sec]: learning rate : 0.000500 loss : 0.335453 +[23:07:30.730] Epoch 19 Evaluation: +[23:10:24.641] average MSE: 0.046033544476599804 average PSNR: 26.40087948103995 average SSIM: 0.6883712807609775 +[23:11:02.865] iteration 11600 [38.20 sec]: learning rate : 0.000500 loss : 0.371875 +[23:11:51.230] iteration 11700 [86.57 sec]: learning rate : 0.000500 loss : 0.367711 +[23:12:38.994] iteration 11800 [134.33 sec]: learning rate : 0.000500 loss : 0.327574 +[23:13:26.651] iteration 11900 [181.99 sec]: learning rate : 0.000500 loss : 0.275928 +[23:14:15.074] iteration 12000 [230.41 sec]: learning rate : 0.000500 loss : 0.262840 +[23:15:00.772] Epoch 20 Evaluation: +[23:17:52.920] average MSE: 0.04615437403951646 average PSNR: 26.39139812497202 average SSIM: 0.6951646817037717 +[23:17:55.014] iteration 12100 [2.07 sec]: learning rate : 0.000500 loss : 0.288316 +[23:18:42.469] iteration 12200 [49.52 sec]: learning rate : 0.000500 loss : 0.307961 +[23:19:30.319] iteration 12300 [97.38 sec]: learning rate : 0.000500 loss : 0.289404 +[23:20:17.911] iteration 12400 [144.97 sec]: learning rate : 0.000500 loss : 0.465381 +[23:21:05.411] iteration 12500 [192.47 sec]: learning rate : 0.000500 loss : 0.357707 +[23:21:53.064] iteration 12600 [240.12 sec]: learning rate : 0.000500 loss : 0.260140 +[23:22:27.316] Epoch 21 Evaluation: +[23:25:19.442] average MSE: 0.042448912957821106 average PSNR: 26.74953759196258 average SSIM: 0.7099657611168393 +[23:25:33.103] iteration 12700 [13.64 sec]: learning rate : 0.000500 loss : 0.303646 +[23:26:20.600] iteration 12800 [61.13 sec]: learning rate : 0.000500 loss : 0.381682 +[23:27:08.193] iteration 12900 [108.73 sec]: learning rate : 0.000500 loss : 0.287078 +[23:27:55.660] iteration 13000 [156.19 sec]: learning rate : 0.000500 loss : 0.377564 +[23:28:43.300] iteration 
13100 [203.83 sec]: learning rate : 0.000500 loss : 0.350222 +[23:29:30.912] iteration 13200 [251.45 sec]: learning rate : 0.000500 loss : 0.303255 +[23:29:54.133] Epoch 22 Evaluation: +[23:32:51.068] average MSE: 0.04590690932425879 average PSNR: 26.411124108686497 average SSIM: 0.6913507240940879 +[23:33:15.989] iteration 13300 [24.90 sec]: learning rate : 0.000500 loss : 0.321236 +[23:34:03.678] iteration 13400 [72.59 sec]: learning rate : 0.000500 loss : 0.350965 +[23:34:51.269] iteration 13500 [120.18 sec]: learning rate : 0.000500 loss : 0.348544 +[23:35:38.869] iteration 13600 [167.80 sec]: learning rate : 0.000500 loss : 0.297837 +[23:36:27.391] iteration 13700 [216.30 sec]: learning rate : 0.000500 loss : 2.761527 +[23:37:14.954] iteration 13800 [263.86 sec]: learning rate : 0.000500 loss : 0.267808 +[23:37:26.361] Epoch 23 Evaluation: +[23:40:19.748] average MSE: 0.03958000176594672 average PSNR: 27.05598754646513 average SSIM: 0.7305760395754495 +[23:40:56.312] iteration 13900 [36.54 sec]: learning rate : 0.000500 loss : 0.291541 +[23:41:44.337] iteration 14000 [84.56 sec]: learning rate : 0.000500 loss : 0.309151 +[23:42:32.550] iteration 14100 [132.78 sec]: learning rate : 0.000500 loss : 0.402981 +[23:43:20.119] iteration 14200 [180.35 sec]: learning rate : 0.000500 loss : 0.306888 +[23:44:07.753] iteration 14300 [227.98 sec]: learning rate : 0.000500 loss : 0.361749 +[23:44:55.240] iteration 14400 [275.47 sec]: learning rate : 0.000500 loss : 0.354988 +[23:44:55.277] Epoch 24 Evaluation: +[23:47:47.580] average MSE: 0.03924386394635124 average PSNR: 27.092958089637886 average SSIM: 0.7251519003956948 +[23:48:35.842] iteration 14500 [48.24 sec]: learning rate : 0.000500 loss : 0.356907 +[23:49:23.685] iteration 14600 [96.08 sec]: learning rate : 0.000500 loss : 0.326411 +[23:50:11.190] iteration 14700 [143.59 sec]: learning rate : 0.000500 loss : 0.365109 +[23:50:58.787] iteration 14800 [191.18 sec]: learning rate : 0.000500 loss : 0.465894 +[23:51:46.406] iteration 14900 [238.80 sec]: learning rate : 0.000500 loss : 0.370995 +[23:52:22.512] Epoch 25 Evaluation: +[23:55:15.961] average MSE: 0.03985935225308438 average PSNR: 27.03137135344694 average SSIM: 0.7343327229412363 +[23:55:27.555] iteration 15000 [11.57 sec]: learning rate : 0.000500 loss : 6.066886 +[23:56:15.200] iteration 15100 [59.21 sec]: learning rate : 0.000500 loss : 0.298254 +[23:57:02.713] iteration 15200 [106.73 sec]: learning rate : 0.000500 loss : 0.434818 +[23:57:50.356] iteration 15300 [154.37 sec]: learning rate : 0.000500 loss : 0.343833 +[23:58:38.763] iteration 15400 [202.78 sec]: learning rate : 0.000500 loss : 0.397200 +[23:59:26.341] iteration 15500 [250.36 sec]: learning rate : 0.000500 loss : 0.353914 +[23:59:51.133] Epoch 26 Evaluation: +[00:02:44.223] average MSE: 0.03482254749470931 average PSNR: 27.620202233680494 average SSIM: 0.7562714207468844 +[00:03:07.264] iteration 15600 [23.02 sec]: learning rate : 0.000500 loss : 0.390343 +[00:03:54.878] iteration 15700 [70.63 sec]: learning rate : 0.000500 loss : 13.114146 +[00:04:42.784] iteration 15800 [118.54 sec]: learning rate : 0.000500 loss : 0.334049 +[00:05:30.337] iteration 15900 [166.09 sec]: learning rate : 0.000500 loss : 0.310368 +[00:06:18.336] iteration 16000 [214.09 sec]: learning rate : 0.000500 loss : 0.356434 +[00:07:05.795] iteration 16100 [261.55 sec]: learning rate : 0.000500 loss : 0.383474 +[00:07:19.202] Epoch 27 Evaluation: +[00:10:13.847] average MSE: 0.040321625621777775 average PSNR: 26.979020839669356 average SSIM: 
0.7238449969321428 +[00:10:48.504] iteration 16200 [34.63 sec]: learning rate : 0.000500 loss : 0.319681 +[00:11:35.994] iteration 16300 [82.12 sec]: learning rate : 0.000500 loss : 0.394594 +[00:12:24.097] iteration 16400 [130.23 sec]: learning rate : 0.000500 loss : 0.333772 +[00:13:11.631] iteration 16500 [177.76 sec]: learning rate : 0.000500 loss : 0.306852 +[00:13:59.097] iteration 16600 [225.23 sec]: learning rate : 0.000500 loss : 0.261203 +[00:14:46.699] iteration 16700 [272.83 sec]: learning rate : 0.000500 loss : 0.327476 +[00:14:48.609] Epoch 28 Evaluation: +[00:17:46.150] average MSE: 0.03705594143552173 average PSNR: 27.341090097912865 average SSIM: 0.7387757572608024 +[00:18:32.487] iteration 16800 [46.31 sec]: learning rate : 0.000500 loss : 0.392546 +[00:19:20.001] iteration 16900 [93.83 sec]: learning rate : 0.000500 loss : 0.301086 +[00:20:07.644] iteration 17000 [141.47 sec]: learning rate : 0.000500 loss : 0.367053 +[00:20:55.768] iteration 17100 [189.60 sec]: learning rate : 0.000500 loss : 0.347916 +[00:21:43.249] iteration 17200 [237.08 sec]: learning rate : 0.000500 loss : 0.363051 +[00:22:21.313] Epoch 29 Evaluation: +[00:25:15.612] average MSE: 0.03850442789902187 average PSNR: 27.1762695323486 average SSIM: 0.7278698197384821 +[00:25:25.293] iteration 17300 [9.66 sec]: learning rate : 0.000500 loss : 0.228806 +[00:26:12.697] iteration 17400 [57.06 sec]: learning rate : 0.000500 loss : 0.494406 +[00:27:00.753] iteration 17500 [105.12 sec]: learning rate : 0.000500 loss : 0.358787 +[00:27:48.273] iteration 17600 [152.64 sec]: learning rate : 0.000500 loss : 0.242644 +[00:28:35.911] iteration 17700 [200.28 sec]: learning rate : 0.000500 loss : 0.319465 +[00:29:23.659] iteration 17800 [248.02 sec]: learning rate : 0.000500 loss : 0.287138 +[00:29:50.291] Epoch 30 Evaluation: +[00:32:47.401] average MSE: 0.03997415456418953 average PSNR: 27.016634394198118 average SSIM: 0.7291501230499602 +[00:33:08.670] iteration 17900 [21.25 sec]: learning rate : 0.000500 loss : 0.374152 +[00:33:56.164] iteration 18000 [68.74 sec]: learning rate : 0.000500 loss : 0.325746 +[00:34:44.292] iteration 18100 [116.87 sec]: learning rate : 0.000500 loss : 0.335041 +[00:35:31.862] iteration 18200 [164.44 sec]: learning rate : 0.000500 loss : 0.308156 +[00:36:19.495] iteration 18300 [212.07 sec]: learning rate : 0.000500 loss : 0.310939 +[00:37:07.490] iteration 18400 [260.07 sec]: learning rate : 0.000500 loss : 0.309406 +[00:37:22.688] Epoch 31 Evaluation: +[00:40:14.380] average MSE: 0.039575762074702925 average PSNR: 27.059808631648608 average SSIM: 0.729103566700112 +[00:40:46.944] iteration 18500 [32.54 sec]: learning rate : 0.000500 loss : 0.297921 +[00:41:34.639] iteration 18600 [80.24 sec]: learning rate : 0.000500 loss : 0.314928 +[00:42:22.386] iteration 18700 [128.00 sec]: learning rate : 0.000500 loss : 0.290971 +[00:43:11.211] iteration 18800 [176.81 sec]: learning rate : 0.000500 loss : 0.313285 +[00:43:58.943] iteration 18900 [224.54 sec]: learning rate : 0.000500 loss : 14.628969 +[00:44:46.500] iteration 19000 [272.10 sec]: learning rate : 0.000500 loss : 0.359118 +[00:44:50.299] Epoch 32 Evaluation: +[00:47:46.712] average MSE: 0.05967490656778383 average PSNR: 25.25828122824997 average SSIM: 0.6583714020356505 +[00:48:31.052] iteration 19100 [44.32 sec]: learning rate : 0.000500 loss : 0.310000 +[00:49:19.075] iteration 19200 [92.34 sec]: learning rate : 0.000500 loss : 0.388818 +[00:50:06.482] iteration 19300 [139.75 sec]: learning rate : 0.000500 loss : 0.347395 
+[00:50:54.152] iteration 19400 [187.42 sec]: learning rate : 0.000500 loss : 0.339918 +[00:51:41.910] iteration 19500 [235.17 sec]: learning rate : 0.000500 loss : 0.279169 +[00:52:22.164] Epoch 33 Evaluation: +[00:55:15.009] average MSE: 0.04011390113574269 average PSNR: 26.999915665841975 average SSIM: 0.7225267314806482 +[00:55:22.820] iteration 19600 [7.79 sec]: learning rate : 0.000500 loss : 0.349583 +[00:56:10.533] iteration 19700 [55.50 sec]: learning rate : 0.000500 loss : 0.301270 +[00:56:58.069] iteration 19800 [103.04 sec]: learning rate : 0.000500 loss : 0.335707 +[00:57:45.469] iteration 19900 [150.44 sec]: learning rate : 0.000500 loss : 0.246654 +[00:58:33.078] iteration 20000 [198.05 sec]: learning rate : 0.000125 loss : 0.415908 +[00:58:33.239] save model to model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/iter_20000.pth +[00:59:20.853] iteration 20100 [245.82 sec]: learning rate : 0.000250 loss : 0.317764 +[00:59:49.340] Epoch 34 Evaluation: +[01:02:42.357] average MSE: 0.04430757746882354 average PSNR: 26.570994201658216 average SSIM: 0.7015537962824271 +[01:03:01.534] iteration 20200 [19.15 sec]: learning rate : 0.000250 loss : 0.332954 +[01:03:49.132] iteration 20300 [66.75 sec]: learning rate : 0.000250 loss : 0.372904 +[01:04:36.564] iteration 20400 [114.18 sec]: learning rate : 0.000250 loss : 0.343529 +[01:05:24.078] iteration 20500 [161.70 sec]: learning rate : 0.000250 loss : 0.308355 +[01:06:12.127] iteration 20600 [209.75 sec]: learning rate : 0.000250 loss : 0.313423 +[01:06:59.991] iteration 20700 [257.61 sec]: learning rate : 0.000250 loss : 0.410260 +[01:07:17.215] Epoch 35 Evaluation: +[01:10:15.205] average MSE: 0.04657735086558031 average PSNR: 26.347560627009475 average SSIM: 0.6866379942554958 +[01:10:45.908] iteration 20800 [30.68 sec]: learning rate : 0.000250 loss : 0.339923 +[01:11:33.959] iteration 20900 [78.73 sec]: learning rate : 0.000250 loss : 0.372296 +[01:12:21.524] iteration 21000 [126.30 sec]: learning rate : 0.000250 loss : 0.378322 +[01:13:09.640] iteration 21100 [174.41 sec]: learning rate : 0.000250 loss : 0.333136 +[01:13:57.331] iteration 21200 [222.10 sec]: learning rate : 0.000250 loss : 0.390876 +[01:14:44.965] iteration 21300 [269.74 sec]: learning rate : 0.000250 loss : 0.345687 +[01:14:50.683] Epoch 36 Evaluation: +[01:17:51.130] average MSE: 0.04046242459732118 average PSNR: 26.95887662480692 average SSIM: 0.7173828289185201 +[01:18:33.355] iteration 21400 [42.22 sec]: learning rate : 0.000250 loss : 0.340548 +[01:19:21.209] iteration 21500 [90.06 sec]: learning rate : 0.000250 loss : 0.330627 +[01:20:08.727] iteration 21600 [137.57 sec]: learning rate : 0.000250 loss : 0.316805 +[01:20:56.212] iteration 21700 [185.05 sec]: learning rate : 0.000250 loss : 0.349434 +[01:21:43.655] iteration 21800 [232.50 sec]: learning rate : 0.000250 loss : 0.291189 +[01:22:25.970] Epoch 37 Evaluation: +[01:25:19.468] average MSE: 0.04405573662763497 average PSNR: 26.590636110170706 average SSIM: 0.7047334973806887 +[01:25:25.362] iteration 21900 [5.87 sec]: learning rate : 0.000250 loss : 0.359173 +[01:26:12.900] iteration 22000 [53.41 sec]: learning rate : 0.000250 loss : 0.380685 +[01:27:00.383] iteration 22100 [100.89 sec]: learning rate : 0.000250 loss : 0.272086 +[01:27:48.266] iteration 22200 [148.77 sec]: learning rate : 0.000250 loss : 0.405956 +[01:28:36.381] iteration 22300 [196.89 sec]: learning rate : 0.000250 loss : 0.307451 +[01:29:23.933] iteration 22400 [244.44 sec]: learning rate : 0.000250 loss : 0.428795 +[01:29:54.480] 
Epoch 38 Evaluation: +[01:32:53.240] average MSE: 0.04790893898865694 average PSNR: 26.227202305554165 average SSIM: 0.6966467495908375 +[01:33:10.489] iteration 22500 [17.22 sec]: learning rate : 0.000250 loss : 0.295306 +[01:33:58.284] iteration 22600 [65.02 sec]: learning rate : 0.000250 loss : 0.303727 +[01:34:45.834] iteration 22700 [112.57 sec]: learning rate : 0.000250 loss : 0.242495 +[01:35:33.365] iteration 22800 [160.10 sec]: learning rate : 0.000250 loss : 0.384567 +[01:36:20.973] iteration 22900 [207.71 sec]: learning rate : 0.000250 loss : 0.373846 +[01:37:09.834] iteration 23000 [256.57 sec]: learning rate : 0.000250 loss : 0.323574 +[01:37:28.929] Epoch 39 Evaluation: +[01:40:29.997] average MSE: 0.04413200322696128 average PSNR: 26.580870482310573 average SSIM: 0.7068812243649911 +[01:40:58.866] iteration 23100 [28.85 sec]: learning rate : 0.000250 loss : 0.372645 +[01:41:46.379] iteration 23200 [76.36 sec]: learning rate : 0.000250 loss : 0.326880 +[01:42:33.998] iteration 23300 [123.98 sec]: learning rate : 0.000250 loss : 0.385134 +[01:43:21.765] iteration 23400 [171.74 sec]: learning rate : 0.000250 loss : 0.357587 +[01:44:09.197] iteration 23500 [219.18 sec]: learning rate : 0.000250 loss : 0.320089 +[01:44:57.376] iteration 23600 [267.36 sec]: learning rate : 0.000250 loss : 0.341329 +[01:45:05.006] Epoch 40 Evaluation: +[01:47:57.503] average MSE: 0.04554665667003532 average PSNR: 26.44510490326057 average SSIM: 0.6978591246714825 +[01:48:37.693] iteration 23700 [40.17 sec]: learning rate : 0.000250 loss : 0.329959 +[01:49:25.547] iteration 23800 [88.02 sec]: learning rate : 0.000250 loss : 0.338303 +[01:50:13.065] iteration 23900 [135.54 sec]: learning rate : 0.000250 loss : 0.247601 +[01:51:01.069] iteration 24000 [183.54 sec]: learning rate : 0.000250 loss : 0.237181 +[01:51:48.604] iteration 24100 [231.08 sec]: learning rate : 0.000250 loss : 8.774884 +[01:52:32.302] Epoch 41 Evaluation: +[01:55:25.749] average MSE: 0.05362989275386815 average PSNR: 25.737739682092016 average SSIM: 0.6781263472589788 +[01:55:29.747] iteration 24200 [3.98 sec]: learning rate : 0.000250 loss : 0.309896 +[01:56:17.869] iteration 24300 [52.10 sec]: learning rate : 0.000250 loss : 0.314733 +[01:57:05.473] iteration 24400 [99.70 sec]: learning rate : 0.000250 loss : 0.386326 +[01:57:53.039] iteration 24500 [147.27 sec]: learning rate : 0.000250 loss : 0.384576 +[01:58:40.534] iteration 24600 [194.76 sec]: learning rate : 0.000250 loss : 0.270798 +[01:59:28.288] iteration 24700 [242.52 sec]: learning rate : 0.000250 loss : 0.346601 +[02:00:00.637] Epoch 42 Evaluation: +[02:03:00.167] average MSE: 0.04675239164206619 average PSNR: 26.33188761426863 average SSIM: 0.694801657559839 +[02:03:15.774] iteration 24800 [15.61 sec]: learning rate : 0.000250 loss : 0.345962 +[02:04:03.580] iteration 24900 [63.39 sec]: learning rate : 0.000250 loss : 0.352576 +[02:04:51.210] iteration 25000 [111.02 sec]: learning rate : 0.000250 loss : 0.327208 +[02:05:38.734] iteration 25100 [158.54 sec]: learning rate : 0.000250 loss : 0.303556 +[02:06:26.389] iteration 25200 [206.20 sec]: learning rate : 0.000250 loss : 0.280243 +[02:07:14.925] iteration 25300 [254.73 sec]: learning rate : 0.000250 loss : 0.304870 +[02:07:35.819] Epoch 43 Evaluation: +[02:10:27.577] average MSE: 0.06158564816277247 average PSNR: 25.130236999302678 average SSIM: 0.6547397780289049 +[02:10:54.329] iteration 25400 [26.73 sec]: learning rate : 0.000250 loss : 0.396573 +[02:11:42.077] iteration 25500 [74.48 sec]: learning rate : 
0.000250 loss : 0.287495 +[02:12:30.362] iteration 25600 [122.76 sec]: learning rate : 0.000250 loss : 0.300030 +[02:13:18.937] iteration 25700 [171.34 sec]: learning rate : 0.000250 loss : 0.431332 +[02:14:06.617] iteration 25800 [219.03 sec]: learning rate : 0.000250 loss : 0.306636 +[02:14:54.254] iteration 25900 [266.65 sec]: learning rate : 0.000250 loss : 0.307573 +[02:15:03.774] Epoch 44 Evaluation: +[02:18:03.473] average MSE: 0.044016125797775744 average PSNR: 26.59239127633583 average SSIM: 0.7026899427804811 +[02:18:42.407] iteration 26000 [38.91 sec]: learning rate : 0.000250 loss : 0.373672 +[02:19:30.710] iteration 26100 [87.22 sec]: learning rate : 0.000250 loss : 0.328082 +[02:20:18.254] iteration 26200 [134.76 sec]: learning rate : 0.000250 loss : 0.295701 +[02:21:06.254] iteration 26300 [182.76 sec]: learning rate : 0.000250 loss : 0.258180 +[02:21:53.855] iteration 26400 [230.36 sec]: learning rate : 0.000250 loss : 0.239893 +[02:22:39.492] Epoch 45 Evaluation: +[02:25:31.856] average MSE: 0.048272289676575875 average PSNR: 26.19246901406976 average SSIM: 0.6869065153905075 +[02:25:33.954] iteration 26500 [2.08 sec]: learning rate : 0.000250 loss : 0.282545 +[02:26:21.617] iteration 26600 [49.74 sec]: learning rate : 0.000250 loss : 0.342533 +[02:27:09.253] iteration 26700 [97.38 sec]: learning rate : 0.000250 loss : 0.283303 +[02:27:56.776] iteration 26800 [144.90 sec]: learning rate : 0.000250 loss : 0.443121 +[02:28:44.350] iteration 26900 [192.47 sec]: learning rate : 0.000250 loss : 0.399467 +[02:29:32.100] iteration 27000 [240.22 sec]: learning rate : 0.000250 loss : 0.275857 +[02:30:06.849] Epoch 46 Evaluation: +[02:32:58.871] average MSE: 0.05126376396989455 average PSNR: 25.935233563393712 average SSIM: 0.6756133964622579 +[02:33:12.420] iteration 27100 [13.53 sec]: learning rate : 0.000250 loss : 0.306499 +[02:34:00.238] iteration 27200 [61.34 sec]: learning rate : 0.000250 loss : 0.324680 +[02:34:47.887] iteration 27300 [108.99 sec]: learning rate : 0.000250 loss : 0.284917 +[02:35:35.868] iteration 27400 [156.97 sec]: learning rate : 0.000250 loss : 4.417865 +[02:36:23.438] iteration 27500 [204.54 sec]: learning rate : 0.000250 loss : 0.314850 +[02:37:10.889] iteration 27600 [251.99 sec]: learning rate : 0.000250 loss : 0.331847 +[02:37:34.204] Epoch 47 Evaluation: +[02:40:30.341] average MSE: 0.050198939530241946 average PSNR: 26.027134919493353 average SSIM: 0.683089488653361 +[02:40:55.540] iteration 27700 [25.18 sec]: learning rate : 0.000250 loss : 1.201362 +[02:41:43.147] iteration 27800 [72.78 sec]: learning rate : 0.000250 loss : 0.362903 +[02:42:30.819] iteration 27900 [120.46 sec]: learning rate : 0.000250 loss : 0.308549 +[02:43:18.461] iteration 28000 [168.10 sec]: learning rate : 0.000250 loss : 0.309510 +[02:44:06.690] iteration 28100 [216.33 sec]: learning rate : 0.000250 loss : 0.312165 +[02:44:54.123] iteration 28200 [263.76 sec]: learning rate : 0.000250 loss : 0.286505 +[02:45:05.517] Epoch 48 Evaluation: +[02:47:59.682] average MSE: 0.059001891839420235 average PSNR: 25.314568066440113 average SSIM: 0.6523377003457622 +[02:48:36.159] iteration 28300 [36.45 sec]: learning rate : 0.000250 loss : 0.542958 +[02:49:23.852] iteration 28400 [84.17 sec]: learning rate : 0.000250 loss : 0.308223 +[02:50:11.711] iteration 28500 [132.01 sec]: learning rate : 0.000250 loss : 0.374035 +[02:50:59.467] iteration 28600 [179.76 sec]: learning rate : 0.000250 loss : 0.335166 +[02:51:47.357] iteration 28700 [227.65 sec]: learning rate : 0.000250 loss : 
0.331159 +[02:52:35.002] iteration 28800 [275.30 sec]: learning rate : 0.000250 loss : 0.337936 +[02:52:35.039] Epoch 49 Evaluation: +[02:55:33.447] average MSE: 0.04888784109143765 average PSNR: 26.138810060646676 average SSIM: 0.6857341646758645 +[02:56:21.266] iteration 28900 [47.80 sec]: learning rate : 0.000250 loss : 0.393140 +[02:57:09.209] iteration 29000 [95.74 sec]: learning rate : 0.000250 loss : 0.357308 +[02:57:57.097] iteration 29100 [143.63 sec]: learning rate : 0.000250 loss : 0.627793 +[02:58:44.661] iteration 29200 [191.19 sec]: learning rate : 0.000250 loss : 0.276969 +[02:59:32.234] iteration 29300 [238.76 sec]: learning rate : 0.000250 loss : 0.333909 +[03:00:08.442] Epoch 50 Evaluation: +[03:03:02.246] average MSE: 0.05431749309668862 average PSNR: 25.680831495386744 average SSIM: 0.6691663790637687 +[03:03:13.884] iteration 29400 [11.61 sec]: learning rate : 0.000250 loss : 0.360472 +[03:04:01.375] iteration 29500 [59.11 sec]: learning rate : 0.000250 loss : 0.268575 +[03:04:49.072] iteration 29600 [106.80 sec]: learning rate : 0.000250 loss : 0.317462 +[03:05:37.339] iteration 29700 [155.07 sec]: learning rate : 0.000250 loss : 0.355203 +[03:06:24.987] iteration 29800 [202.72 sec]: learning rate : 0.000250 loss : 0.309521 +[03:07:12.717] iteration 29900 [250.45 sec]: learning rate : 0.000250 loss : 0.338997 +[03:07:37.752] Epoch 51 Evaluation: +[03:10:33.506] average MSE: 0.04170021683404265 average PSNR: 26.82881360137535 average SSIM: 0.7163146215846128 +[03:10:56.607] iteration 30000 [23.08 sec]: learning rate : 0.000250 loss : 0.335182 +[03:11:44.011] iteration 30100 [70.48 sec]: learning rate : 0.000250 loss : 0.243191 +[03:12:31.600] iteration 30200 [118.07 sec]: learning rate : 0.000250 loss : 0.338867 +[03:13:19.214] iteration 30300 [165.68 sec]: learning rate : 0.000250 loss : 0.299266 +[03:14:07.789] iteration 30400 [214.28 sec]: learning rate : 0.000250 loss : 0.391970 +[03:14:56.038] iteration 30500 [262.51 sec]: learning rate : 0.000250 loss : 0.377181 +[03:15:09.381] Epoch 52 Evaluation: +[03:18:11.749] average MSE: 0.050939755727215597 average PSNR: 25.964290360665682 average SSIM: 0.6788929893607261 +[03:18:46.171] iteration 30600 [34.40 sec]: learning rate : 0.000250 loss : 0.312275 +[03:19:34.081] iteration 30700 [82.31 sec]: learning rate : 0.000250 loss : 0.411141 +[03:20:22.267] iteration 30800 [130.50 sec]: learning rate : 0.000250 loss : 0.406719 +[03:21:09.696] iteration 30900 [177.92 sec]: learning rate : 0.000250 loss : 0.324164 +[03:21:57.252] iteration 31000 [225.48 sec]: learning rate : 0.000250 loss : 0.234450 +[03:22:44.767] iteration 31100 [273.00 sec]: learning rate : 0.000250 loss : 0.354348 +[03:22:46.674] Epoch 53 Evaluation: +[03:25:39.859] average MSE: 0.05165929902982657 average PSNR: 25.902988557563614 average SSIM: 0.6744323225494961 +[03:26:25.636] iteration 31200 [45.75 sec]: learning rate : 0.000250 loss : 0.380788 +[03:27:13.278] iteration 31300 [93.40 sec]: learning rate : 0.000250 loss : 0.245452 +[03:28:00.938] iteration 31400 [141.06 sec]: learning rate : 0.000250 loss : 0.381588 +[03:28:48.535] iteration 31500 [188.66 sec]: learning rate : 0.000250 loss : 0.310405 +[03:29:36.263] iteration 31600 [236.38 sec]: learning rate : 0.000250 loss : 0.289484 +[03:30:14.305] Epoch 54 Evaluation: +[03:33:11.911] average MSE: 0.062060614380735915 average PSNR: 25.100593178796846 average SSIM: 0.6485639097125916 +[03:33:21.752] iteration 31700 [9.82 sec]: learning rate : 0.000250 loss : 0.294382 +[03:34:09.321] iteration 31800 
[57.39 sec]: learning rate : 0.000250 loss : 0.315133 +[03:34:57.071] iteration 31900 [105.14 sec]: learning rate : 0.000250 loss : 0.317108 +[03:35:44.664] iteration 32000 [152.73 sec]: learning rate : 0.000250 loss : 0.270293 +[03:36:32.781] iteration 32100 [200.85 sec]: learning rate : 0.000250 loss : 0.403849 +[03:37:20.385] iteration 32200 [248.45 sec]: learning rate : 0.000250 loss : 0.533157 +[03:37:47.227] Epoch 55 Evaluation: +[03:40:37.735] average MSE: 0.06140179901668301 average PSNR: 25.147022316109215 average SSIM: 0.653099148402993 +[03:40:58.843] iteration 32300 [21.08 sec]: learning rate : 0.000250 loss : 0.401023 +[03:41:46.691] iteration 32400 [68.93 sec]: learning rate : 0.000250 loss : 0.313052 +[03:42:34.498] iteration 32500 [116.74 sec]: learning rate : 0.000250 loss : 0.347646 +[03:43:21.941] iteration 32600 [164.18 sec]: learning rate : 0.000250 loss : 0.389868 +[03:44:09.864] iteration 32700 [212.11 sec]: learning rate : 0.000250 loss : 0.337276 +[03:44:57.513] iteration 32800 [259.75 sec]: learning rate : 0.000250 loss : 0.337029 +[03:45:12.866] Epoch 56 Evaluation: +[03:48:11.885] average MSE: 0.05257348479487478 average PSNR: 25.829039394334078 average SSIM: 0.6719635730852858 +[03:48:44.559] iteration 32900 [32.65 sec]: learning rate : 0.000250 loss : 0.368345 +[03:49:32.380] iteration 33000 [80.47 sec]: learning rate : 0.000250 loss : 0.291369 +[03:50:20.929] iteration 33100 [129.02 sec]: learning rate : 0.000250 loss : 0.334224 +[03:51:08.658] iteration 33200 [176.75 sec]: learning rate : 0.000250 loss : 0.313720 +[03:51:56.457] iteration 33300 [224.55 sec]: learning rate : 0.000250 loss : 0.407735 +[03:52:44.116] iteration 33400 [272.21 sec]: learning rate : 0.000250 loss : 0.370358 +[03:52:47.934] Epoch 57 Evaluation: +[03:55:48.851] average MSE: 0.05167834639612688 average PSNR: 25.903678697444665 average SSIM: 0.6757185760553782 +[03:56:33.436] iteration 33500 [44.56 sec]: learning rate : 0.000250 loss : 0.287597 +[03:57:21.141] iteration 33600 [92.27 sec]: learning rate : 0.000250 loss : 0.385074 +[03:58:08.756] iteration 33700 [139.88 sec]: learning rate : 0.000250 loss : 0.320299 +[03:58:56.959] iteration 33800 [188.09 sec]: learning rate : 0.000250 loss : 0.333537 +[03:59:44.868] iteration 33900 [235.99 sec]: learning rate : 0.000250 loss : 0.291803 +[04:00:24.828] Epoch 58 Evaluation: +[04:03:20.312] average MSE: 0.05263230875188429 average PSNR: 25.816201208723797 average SSIM: 0.6749640192461109 +[04:03:28.121] iteration 34000 [7.79 sec]: learning rate : 0.000250 loss : 0.344046 +[04:04:15.778] iteration 34100 [55.44 sec]: learning rate : 0.000250 loss : 0.369877 +[04:05:03.757] iteration 34200 [103.42 sec]: learning rate : 0.000250 loss : 0.301078 +[04:05:51.276] iteration 34300 [150.94 sec]: learning rate : 0.000250 loss : 0.295965 +[04:06:38.921] iteration 34400 [198.59 sec]: learning rate : 0.000250 loss : 0.342335 +[04:07:26.566] iteration 34500 [246.23 sec]: learning rate : 0.000250 loss : 0.244563 +[04:07:55.839] Epoch 59 Evaluation: +[04:10:56.803] average MSE: 0.05902971056220745 average PSNR: 25.322591921783985 average SSIM: 0.6555104647434816 +[04:11:15.997] iteration 34600 [19.17 sec]: learning rate : 0.000250 loss : 0.289545 +[04:12:03.615] iteration 34700 [66.79 sec]: learning rate : 0.000250 loss : 0.359726 +[04:12:51.115] iteration 34800 [114.29 sec]: learning rate : 0.000250 loss : 0.293254 +[04:13:38.626] iteration 34900 [161.80 sec]: learning rate : 0.000250 loss : 0.369365 +[04:14:26.514] iteration 35000 [209.69 sec]: learning 
rate : 0.000250 loss : 0.284519 +[04:15:14.070] iteration 35100 [257.24 sec]: learning rate : 0.000250 loss : 0.401737 +[04:15:31.152] Epoch 60 Evaluation: +[04:18:24.319] average MSE: 0.06019259377555253 average PSNR: 25.233105523028 average SSIM: 0.6524803276796565 +[04:18:54.999] iteration 35200 [30.66 sec]: learning rate : 0.000250 loss : 0.402852 +[04:19:42.408] iteration 35300 [78.07 sec]: learning rate : 0.000250 loss : 0.363003 +[04:20:30.321] iteration 35400 [125.98 sec]: learning rate : 0.000250 loss : 0.290175 +[04:21:18.164] iteration 35500 [173.82 sec]: learning rate : 0.000250 loss : 0.362132 +[04:22:05.740] iteration 35600 [221.40 sec]: learning rate : 0.000250 loss : 0.350807 +[04:22:53.405] iteration 35700 [269.06 sec]: learning rate : 0.000250 loss : 0.360398 +[04:22:59.109] Epoch 61 Evaluation: +[04:25:56.974] average MSE: 0.05431887983760384 average PSNR: 25.68501605059669 average SSIM: 0.6663760829480011 +[04:26:39.554] iteration 35800 [42.56 sec]: learning rate : 0.000250 loss : 0.309338 +[04:27:27.504] iteration 35900 [90.51 sec]: learning rate : 0.000250 loss : 0.378363 +[04:28:15.265] iteration 36000 [138.27 sec]: learning rate : 0.000250 loss : 0.337592 +[04:29:03.006] iteration 36100 [186.01 sec]: learning rate : 0.000250 loss : 0.308382 +[04:29:50.685] iteration 36200 [233.69 sec]: learning rate : 0.000250 loss : 0.324467 +[04:30:32.700] Epoch 62 Evaluation: +[04:33:26.143] average MSE: 0.057857622531017575 average PSNR: 25.406145362884526 average SSIM: 0.6575396527500273 +[04:33:32.030] iteration 36300 [5.86 sec]: learning rate : 0.000250 loss : 0.355723 +[04:34:19.569] iteration 36400 [53.40 sec]: learning rate : 0.000250 loss : 0.334321 +[04:35:07.616] iteration 36500 [101.45 sec]: learning rate : 0.000250 loss : 0.260062 +[04:35:55.161] iteration 36600 [149.00 sec]: learning rate : 0.000250 loss : 0.430419 +[04:36:42.629] iteration 36700 [196.46 sec]: learning rate : 0.000250 loss : 0.381121 +[04:37:30.176] iteration 36800 [244.01 sec]: learning rate : 0.000250 loss : 0.391122 +[04:38:00.822] Epoch 63 Evaluation: +[04:40:58.887] average MSE: 0.07181905310099682 average PSNR: 24.457136972933316 average SSIM: 0.625969407034258 +[04:41:16.304] iteration 36900 [17.40 sec]: learning rate : 0.000250 loss : 0.276887 +[04:42:03.846] iteration 37000 [64.94 sec]: learning rate : 0.000250 loss : 0.328387 +[04:42:51.416] iteration 37100 [112.51 sec]: learning rate : 0.000250 loss : 0.280029 +[04:43:40.366] iteration 37200 [161.46 sec]: learning rate : 0.000250 loss : 0.384261 +[04:44:28.371] iteration 37300 [209.46 sec]: learning rate : 0.000250 loss : 0.356430 +[04:45:16.080] iteration 37400 [257.17 sec]: learning rate : 0.000250 loss : 0.273397 +[04:45:35.097] Epoch 64 Evaluation: +[04:48:26.872] average MSE: 0.06824435388457288 average PSNR: 24.682884769591755 average SSIM: 0.6310780140029288 +[04:48:55.495] iteration 37500 [28.60 sec]: learning rate : 0.000250 loss : 0.323059 +[04:49:43.462] iteration 37600 [76.57 sec]: learning rate : 0.000250 loss : 0.239050 +[04:50:31.595] iteration 37700 [124.70 sec]: learning rate : 0.000250 loss : 0.429137 +[04:51:19.033] iteration 37800 [172.14 sec]: learning rate : 0.000250 loss : 0.350343 +[04:52:06.649] iteration 37900 [219.75 sec]: learning rate : 0.000250 loss : 0.280267 +[04:52:54.562] iteration 38000 [267.67 sec]: learning rate : 0.000250 loss : 0.306713 +[04:53:02.156] Epoch 65 Evaluation: +[04:55:54.726] average MSE: 0.0733262618935935 average PSNR: 24.357183243607917 average SSIM: 0.6233283396057038 +[04:56:35.297] 
iteration 38100 [40.56 sec]: learning rate : 0.000250 loss : 0.296017 +[04:57:22.951] iteration 38200 [88.20 sec]: learning rate : 0.000250 loss : 0.328738 +[04:58:10.644] iteration 38300 [135.90 sec]: learning rate : 0.000250 loss : 1.216481 +[04:58:58.252] iteration 38400 [183.50 sec]: learning rate : 0.000250 loss : 0.267497 +[04:59:46.003] iteration 38500 [231.25 sec]: learning rate : 0.000250 loss : 0.336527 +[05:00:30.098] Epoch 66 Evaluation: +[05:03:25.462] average MSE: 0.07904319390923412 average PSNR: 24.038627972848175 average SSIM: 0.6259480902335873 +[05:03:29.475] iteration 38600 [3.99 sec]: learning rate : 0.000250 loss : 0.300353 +[05:04:17.147] iteration 38700 [51.66 sec]: learning rate : 0.000250 loss : 0.249775 +[05:05:04.761] iteration 38800 [99.28 sec]: learning rate : 0.000250 loss : 0.375181 +[05:05:53.033] iteration 38900 [147.55 sec]: learning rate : 0.000250 loss : 0.390855 +[05:06:40.843] iteration 39000 [195.36 sec]: learning rate : 0.000250 loss : 0.305016 +[05:07:28.425] iteration 39100 [242.94 sec]: learning rate : 0.000250 loss : 0.406348 +[05:08:00.731] Epoch 67 Evaluation: +[05:10:57.978] average MSE: 0.08947222912916235 average PSNR: 23.485979234769935 average SSIM: 0.6073411963291664 +[05:11:13.409] iteration 39200 [15.41 sec]: learning rate : 0.000250 loss : 0.405341 +[05:12:01.546] iteration 39300 [63.55 sec]: learning rate : 0.000250 loss : 0.348744 +[05:12:49.202] iteration 39400 [111.20 sec]: learning rate : 0.000250 loss : 0.342022 +[05:13:36.776] iteration 39500 [158.78 sec]: learning rate : 0.000250 loss : 0.318677 +[05:14:24.494] iteration 39600 [206.49 sec]: learning rate : 0.000250 loss : 0.302515 +[05:15:12.495] iteration 39700 [254.49 sec]: learning rate : 0.000250 loss : 0.372414 +[05:15:33.488] Epoch 68 Evaluation: +[05:18:25.951] average MSE: 0.08331375425902739 average PSNR: 23.805735956734928 average SSIM: 0.6073770105513288 +[05:18:53.086] iteration 39800 [27.13 sec]: learning rate : 0.000250 loss : 0.361283 +[05:19:40.840] iteration 39900 [74.87 sec]: learning rate : 0.000250 loss : 0.290943 +[05:20:28.264] iteration 40000 [122.29 sec]: learning rate : 0.000063 loss : 0.330806 +[05:20:28.421] save model to model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/iter_40000.pth +[05:21:16.468] iteration 40100 [170.50 sec]: learning rate : 0.000125 loss : 0.395671 +[05:22:04.201] iteration 40200 [218.23 sec]: learning rate : 0.000125 loss : 0.261441 +[05:22:52.314] iteration 40300 [266.34 sec]: learning rate : 0.000125 loss : 0.282119 +[05:23:01.844] Epoch 69 Evaluation: +[05:25:54.674] average MSE: 0.08309507055821243 average PSNR: 23.819006628774808 average SSIM: 0.6155098147483944 +[05:26:33.026] iteration 40400 [38.33 sec]: learning rate : 0.000125 loss : 0.366432 +[05:27:21.019] iteration 40500 [86.32 sec]: learning rate : 0.000125 loss : 0.362859 +[05:28:09.319] iteration 40600 [134.63 sec]: learning rate : 0.000125 loss : 16.737963 +[05:28:57.033] iteration 40700 [182.34 sec]: learning rate : 0.000125 loss : 0.248844 +[05:29:44.775] iteration 40800 [230.08 sec]: learning rate : 0.000125 loss : 0.322568 +[05:30:30.541] Epoch 70 Evaluation: +[05:33:27.080] average MSE: 0.06392631667837598 average PSNR: 24.98020716521069 average SSIM: 0.642986599271846 +[05:33:29.195] iteration 40900 [2.09 sec]: learning rate : 0.000125 loss : 0.318055 +[05:34:17.442] iteration 41000 [50.34 sec]: learning rate : 0.000125 loss : 0.313992 +[05:35:05.007] iteration 41100 [97.90 sec]: learning rate : 0.000125 loss : 0.293710 +[05:35:52.710] iteration 41200 
[145.61 sec]: learning rate : 0.000125 loss : 0.397040 +[05:36:40.482] iteration 41300 [193.38 sec]: learning rate : 0.000125 loss : 0.425627 +[05:37:27.947] iteration 41400 [240.84 sec]: learning rate : 0.000125 loss : 0.275752 +[05:38:02.213] Epoch 71 Evaluation: +[05:40:54.775] average MSE: 0.08098695425400268 average PSNR: 23.935610559805404 average SSIM: 0.6202062500347132 +[05:41:08.265] iteration 41500 [13.47 sec]: learning rate : 0.000125 loss : 0.345261 +[05:41:55.964] iteration 41600 [61.18 sec]: learning rate : 0.000125 loss : 0.364337 +[05:42:43.517] iteration 41700 [108.72 sec]: learning rate : 0.000125 loss : 0.306433 +[05:43:31.008] iteration 41800 [156.21 sec]: learning rate : 0.000125 loss : 0.349235 +[05:44:18.507] iteration 41900 [203.71 sec]: learning rate : 0.000125 loss : 0.281104 +[05:45:06.986] iteration 42000 [252.19 sec]: learning rate : 0.000125 loss : 4.187071 +[05:45:29.809] Epoch 72 Evaluation: +[05:48:23.256] average MSE: 0.07618223210049589 average PSNR: 24.199959190776724 average SSIM: 0.6232629226063029 +[05:48:48.289] iteration 42100 [25.01 sec]: learning rate : 0.000125 loss : 0.304946 +[05:49:35.732] iteration 42200 [72.45 sec]: learning rate : 0.000125 loss : 0.344699 +[05:50:23.246] iteration 42300 [119.97 sec]: learning rate : 0.000125 loss : 0.347621 +[05:51:11.769] iteration 42400 [168.49 sec]: learning rate : 0.000125 loss : 0.263111 +[05:51:59.226] iteration 42500 [215.95 sec]: learning rate : 0.000125 loss : 0.304379 +[05:52:46.748] iteration 42600 [263.47 sec]: learning rate : 0.000125 loss : 0.292142 +[05:52:58.145] Epoch 73 Evaluation: +[05:55:50.113] average MSE: 0.10377610740128307 average PSNR: 22.840409565021066 average SSIM: 0.6086074374490249 +[05:56:26.931] iteration 42700 [36.79 sec]: learning rate : 0.000125 loss : 0.274619 +[05:57:14.638] iteration 42800 [84.50 sec]: learning rate : 0.000125 loss : 0.308438 +[05:58:02.198] iteration 42900 [132.06 sec]: learning rate : 0.000125 loss : 0.330449 +[05:58:49.770] iteration 43000 [179.63 sec]: learning rate : 0.000125 loss : 0.314506 +[05:59:37.346] iteration 43100 [227.21 sec]: learning rate : 0.000125 loss : 0.338239 +[06:00:24.853] iteration 43200 [274.72 sec]: learning rate : 0.000125 loss : 0.339500 +[06:00:24.890] Epoch 74 Evaluation: +[06:03:16.915] average MSE: 0.07557617345539733 average PSNR: 24.23913918300305 average SSIM: 0.6198803689087165 +[06:04:05.082] iteration 43300 [48.14 sec]: learning rate : 0.000125 loss : 0.329102 +[06:04:52.467] iteration 43400 [95.53 sec]: learning rate : 0.000125 loss : 0.314495 +[06:05:39.999] iteration 43500 [143.06 sec]: learning rate : 0.000125 loss : 0.347698 +[06:06:27.540] iteration 43600 [190.60 sec]: learning rate : 0.000125 loss : 0.251182 +[06:07:15.292] iteration 43700 [238.35 sec]: learning rate : 0.000125 loss : 0.347466 +[06:07:51.660] Epoch 75 Evaluation: +[06:10:47.842] average MSE: 0.06741794076583255 average PSNR: 24.738362059488562 average SSIM: 0.6355396016766169 +[06:10:59.426] iteration 43800 [11.56 sec]: learning rate : 0.000125 loss : 0.340769 +[06:11:46.859] iteration 43900 [58.99 sec]: learning rate : 0.000125 loss : 0.313148 +[06:12:35.106] iteration 44000 [107.24 sec]: learning rate : 0.000125 loss : 0.451032 +[06:13:22.981] iteration 44100 [155.11 sec]: learning rate : 0.000125 loss : 0.401558 +[06:14:10.411] iteration 44200 [202.54 sec]: learning rate : 0.000125 loss : 0.248683 +[06:14:58.006] iteration 44300 [250.14 sec]: learning rate : 0.000125 loss : 0.332465 +[06:15:23.091] Epoch 76 Evaluation: +[06:18:16.370] 
average MSE: 0.10485240273496221 average PSNR: 22.795851548632424 average SSIM: 0.6110987018048483 +[06:18:40.108] iteration 44400 [23.71 sec]: learning rate : 0.000125 loss : 0.359116 +[06:19:27.505] iteration 44500 [71.11 sec]: learning rate : 0.000125 loss : 0.277397 +[06:20:15.067] iteration 44600 [118.67 sec]: learning rate : 0.000125 loss : 0.313037 +[06:21:02.763] iteration 44700 [166.37 sec]: learning rate : 0.000125 loss : 0.304897 +[06:21:51.011] iteration 44800 [214.62 sec]: learning rate : 0.000125 loss : 0.406189 +[06:22:38.612] iteration 44900 [262.22 sec]: learning rate : 0.000125 loss : 0.354335 +[06:22:51.912] Epoch 77 Evaluation: +[06:25:48.156] average MSE: 0.07898292350730444 average PSNR: 24.0444903580632 average SSIM: 0.6121242837487405 +[06:26:22.501] iteration 45000 [34.32 sec]: learning rate : 0.000125 loss : 0.291974 +[06:27:10.079] iteration 45100 [81.90 sec]: learning rate : 0.000125 loss : 0.406747 +[06:27:57.904] iteration 45200 [129.72 sec]: learning rate : 0.000125 loss : 0.325827 +[06:28:45.332] iteration 45300 [177.15 sec]: learning rate : 0.000125 loss : 0.303848 +[06:29:33.449] iteration 45400 [225.29 sec]: learning rate : 0.000125 loss : 0.303044 +[06:30:21.502] iteration 45500 [273.32 sec]: learning rate : 0.000125 loss : 0.335152 +[06:30:23.410] Epoch 78 Evaluation: +[06:33:15.804] average MSE: 0.07674502742194908 average PSNR: 24.172422534167442 average SSIM: 0.6262531469183948 +[06:34:01.944] iteration 45600 [46.13 sec]: learning rate : 0.000125 loss : 0.395570 +[06:34:49.484] iteration 45700 [93.66 sec]: learning rate : 0.000125 loss : 0.262133 +[06:35:37.446] iteration 45800 [141.62 sec]: learning rate : 0.000125 loss : 0.351936 +[06:36:24.948] iteration 45900 [189.12 sec]: learning rate : 0.000125 loss : 0.271613 +[06:37:12.646] iteration 46000 [236.82 sec]: learning rate : 0.000125 loss : 0.330066 +[06:37:50.652] Epoch 79 Evaluation: +[06:40:44.113] average MSE: 0.07883388009416338 average PSNR: 24.054350688726036 average SSIM: 0.6197475034172749 +[06:40:53.809] iteration 46100 [9.67 sec]: learning rate : 0.000125 loss : 0.247677 +[06:41:41.382] iteration 46200 [57.25 sec]: learning rate : 0.000125 loss : 0.295583 +[06:42:28.927] iteration 46300 [104.79 sec]: learning rate : 0.000125 loss : 0.320742 +[06:43:16.356] iteration 46400 [152.22 sec]: learning rate : 0.000125 loss : 0.253537 +[06:44:03.963] iteration 46500 [199.83 sec]: learning rate : 0.000125 loss : 0.281079 +[06:44:51.540] iteration 46600 [247.40 sec]: learning rate : 0.000125 loss : 0.348127 +[06:45:18.098] Epoch 80 Evaluation: +[06:48:11.261] average MSE: 0.07822700771197137 average PSNR: 24.089454894999523 average SSIM: 0.6231309106277675 +[06:48:32.324] iteration 46700 [21.04 sec]: learning rate : 0.000125 loss : 0.358059 +[06:49:19.910] iteration 46800 [68.63 sec]: learning rate : 0.000125 loss : 0.301462 +[06:50:07.439] iteration 46900 [116.16 sec]: learning rate : 0.000125 loss : 0.314397 +[06:50:54.867] iteration 47000 [163.58 sec]: learning rate : 0.000125 loss : 0.290838 +[06:51:42.972] iteration 47100 [211.69 sec]: learning rate : 0.000125 loss : 0.249148 +[06:52:30.965] iteration 47200 [259.68 sec]: learning rate : 0.000125 loss : 0.371296 +[06:52:46.328] Epoch 81 Evaluation: +[06:55:38.003] average MSE: 0.1060817138078806 average PSNR: 22.744166000467164 average SSIM: 0.6106482207379905 +[06:56:10.585] iteration 47300 [32.56 sec]: learning rate : 0.000125 loss : 0.293067 +[06:56:58.712] iteration 47400 [80.69 sec]: learning rate : 0.000125 loss : 0.276415 
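
Across this log the learning rate starts at 5e-4 and halves at iterations 20000, 40000, and 60000 (0.000500 → 0.000250 → 0.000125 → 0.000063), with a checkpoint written to `model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/iter_<N>.pth` at each 20,000-iteration boundary. The sketch below reproduces that pattern with a standard PyTorch `StepLR` scheduler; the optimizer choice (Adam), the `train_step` helper, and the surrounding loop structure are assumptions, not the repository's actual training code.

```python
# Hypothetical sketch of the stepped learning-rate schedule and periodic
# checkpointing seen in this log (lr halves every 20k iterations, checkpoint at
# each boundary). Adam and the loop structure are assumptions for illustration.
import os
import torch


def train(model, train_loader, train_step,
          save_dir="model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time"):
    """train_step(model, batch, optimizer) is an assumed helper returning the loss."""
    os.makedirs(save_dir, exist_ok=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    # Halve the learning rate every 20,000 iterations: 5e-4 -> 2.5e-4 -> 1.25e-4 -> 6.3e-5
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20000, gamma=0.5)
    for iteration, batch in enumerate(train_loader, start=1):
        loss = train_step(model, batch, optimizer)
        scheduler.step()                          # one scheduler step per iteration
        if iteration % 100 == 0:
            lr = optimizer.param_groups[0]["lr"]
            print(f"iteration {iteration}: learning rate : {lr:.6f} loss : {loss:.6f}")
        if iteration % 20000 == 0:
            path = os.path.join(save_dir, f"iter_{iteration}.pth")
            torch.save(model.state_dict(), path)
            print(f"save model to {path}")
```
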
+[06:57:47.626] iteration 47500 [129.60 sec]: learning rate : 0.000125 loss : 0.329457 +[06:58:35.220] iteration 47600 [177.19 sec]: learning rate : 0.000125 loss : 0.326354 +[06:59:22.922] iteration 47700 [224.90 sec]: learning rate : 0.000125 loss : 0.393516 +[07:00:10.480] iteration 47800 [272.45 sec]: learning rate : 0.000125 loss : 0.287811 +[07:00:14.288] Epoch 82 Evaluation: +[07:03:06.980] average MSE: 0.08819612542337056 average PSNR: 23.55984180314861 average SSIM: 0.6077354343181461 +[07:03:51.323] iteration 47900 [44.32 sec]: learning rate : 0.000125 loss : 0.260296 +[07:04:38.918] iteration 48000 [91.91 sec]: learning rate : 0.000125 loss : 0.307709 +[07:05:26.340] iteration 48100 [139.34 sec]: learning rate : 0.000125 loss : 0.288247 +[07:06:14.314] iteration 48200 [187.31 sec]: learning rate : 0.000125 loss : 0.304000 +[07:07:01.923] iteration 48300 [234.92 sec]: learning rate : 0.000125 loss : 0.276874 +[07:07:41.880] Epoch 83 Evaluation: +[07:10:37.044] average MSE: 0.08234128277109429 average PSNR: 23.857193617435204 average SSIM: 0.6218921521575146 +[07:10:44.854] iteration 48400 [7.79 sec]: learning rate : 0.000125 loss : 0.340928 +[07:11:32.442] iteration 48500 [55.37 sec]: learning rate : 0.000125 loss : 0.339700 +[07:12:19.907] iteration 48600 [102.84 sec]: learning rate : 0.000125 loss : 0.345045 +[07:13:07.418] iteration 48700 [150.35 sec]: learning rate : 0.000125 loss : 0.231365 +[07:13:54.996] iteration 48800 [197.93 sec]: learning rate : 0.000125 loss : 0.355150 +[07:14:42.983] iteration 48900 [245.92 sec]: learning rate : 0.000125 loss : 0.277130 +[07:15:12.126] Epoch 84 Evaluation: +[07:18:11.409] average MSE: 0.08082092138275258 average PSNR: 23.94326260740048 average SSIM: 0.6176250109847082 +[07:18:30.627] iteration 49000 [19.20 sec]: learning rate : 0.000125 loss : 0.303154 +[07:19:18.271] iteration 49100 [66.84 sec]: learning rate : 0.000125 loss : 0.342049 +[07:20:06.165] iteration 49200 [114.73 sec]: learning rate : 0.000125 loss : 0.315781 +[07:20:53.681] iteration 49300 [162.25 sec]: learning rate : 0.000125 loss : 0.325738 +[07:21:41.257] iteration 49400 [209.82 sec]: learning rate : 0.000125 loss : 0.299930 +[07:22:29.377] iteration 49500 [257.95 sec]: learning rate : 0.000125 loss : 0.374799 +[07:22:46.497] Epoch 85 Evaluation: +[07:25:44.118] average MSE: 0.1042423325544128 average PSNR: 22.820861669390098 average SSIM: 0.6133272138217841 +[07:26:14.927] iteration 49600 [30.79 sec]: learning rate : 0.000125 loss : 0.386529 +[07:27:02.502] iteration 49700 [78.36 sec]: learning rate : 0.000125 loss : 0.382742 +[07:27:50.552] iteration 49800 [126.41 sec]: learning rate : 0.000125 loss : 0.366645 +[07:28:38.088] iteration 49900 [173.95 sec]: learning rate : 0.000125 loss : 0.375428 +[07:29:25.582] iteration 50000 [221.44 sec]: learning rate : 0.000125 loss : 0.347116 +[07:30:13.135] iteration 50100 [268.99 sec]: learning rate : 0.000125 loss : 0.332396 +[07:30:18.833] Epoch 86 Evaluation: +[07:33:12.549] average MSE: 0.09249124310784305 average PSNR: 23.345652474356015 average SSIM: 0.6128495294842529 +[07:33:54.884] iteration 50200 [42.31 sec]: learning rate : 0.000125 loss : 0.299811 +[07:34:42.426] iteration 50300 [89.85 sec]: learning rate : 0.000125 loss : 0.297577 +[07:35:29.982] iteration 50400 [137.41 sec]: learning rate : 0.000125 loss : 0.328777 +[07:36:17.532] iteration 50500 [184.96 sec]: learning rate : 0.000125 loss : 0.325035 +[07:37:05.394] iteration 50600 [232.82 sec]: learning rate : 0.000125 loss : 0.336013 +[07:37:47.230] Epoch 87 
Evaluation: +[07:40:38.237] average MSE: 0.1015773886589643 average PSNR: 22.935932992087437 average SSIM: 0.60893206879064 +[07:40:44.123] iteration 50700 [5.86 sec]: learning rate : 0.000125 loss : 0.357583 +[07:41:31.529] iteration 50800 [53.27 sec]: learning rate : 0.000125 loss : 0.320205 +[07:42:20.139] iteration 50900 [101.88 sec]: learning rate : 0.000125 loss : 0.309823 +[07:43:07.807] iteration 51000 [149.55 sec]: learning rate : 0.000125 loss : 0.380958 +[07:43:55.399] iteration 51100 [197.14 sec]: learning rate : 0.000125 loss : 0.289181 +[07:44:43.033] iteration 51200 [244.77 sec]: learning rate : 0.000125 loss : 0.414455 +[07:45:13.491] Epoch 88 Evaluation: +[07:48:08.483] average MSE: 0.07905765663211002 average PSNR: 24.040374202481114 average SSIM: 0.6280081052870635 +[07:48:25.935] iteration 51300 [17.43 sec]: learning rate : 0.000125 loss : 0.290548 +[07:49:13.365] iteration 51400 [64.86 sec]: learning rate : 0.000125 loss : 0.305067 +[07:50:00.969] iteration 51500 [112.46 sec]: learning rate : 0.000125 loss : 0.306299 +[07:50:49.000] iteration 51600 [160.49 sec]: learning rate : 0.000125 loss : 0.372743 +[07:51:36.753] iteration 51700 [208.25 sec]: learning rate : 0.000125 loss : 0.692765 +[07:52:24.888] iteration 51800 [256.38 sec]: learning rate : 0.000125 loss : 0.251906 +[07:52:43.963] Epoch 89 Evaluation: +[07:55:43.673] average MSE: 0.08707405348458387 average PSNR: 23.611626176762634 average SSIM: 0.6114362348315063 +[07:56:12.443] iteration 51900 [28.75 sec]: learning rate : 0.000125 loss : 0.434605 +[07:57:00.220] iteration 52000 [76.52 sec]: learning rate : 0.000125 loss : 0.271212 +[07:57:47.856] iteration 52100 [124.16 sec]: learning rate : 0.000125 loss : 0.445839 +[07:58:35.685] iteration 52200 [171.99 sec]: learning rate : 0.000125 loss : 0.408202 +[07:59:24.083] iteration 52300 [220.39 sec]: learning rate : 0.000125 loss : 0.416918 +[08:00:11.638] iteration 52400 [267.94 sec]: learning rate : 0.000125 loss : 0.299927 +[08:00:19.239] Epoch 90 Evaluation: +[08:03:11.404] average MSE: 0.090173458832058 average PSNR: 23.45493762424633 average SSIM: 0.6102140762743327 +[08:03:51.458] iteration 52500 [40.03 sec]: learning rate : 0.000125 loss : 0.349693 +[08:04:39.860] iteration 52600 [88.43 sec]: learning rate : 0.000125 loss : 0.371473 +[08:05:27.387] iteration 52700 [135.96 sec]: learning rate : 0.000125 loss : 0.271418 +[08:06:15.103] iteration 52800 [183.68 sec]: learning rate : 0.000125 loss : 0.269991 +[08:07:02.953] iteration 52900 [231.53 sec]: learning rate : 0.000125 loss : 0.352167 +[08:07:46.868] Epoch 91 Evaluation: +[08:10:45.485] average MSE: 0.07424496999346278 average PSNR: 24.318217812449603 average SSIM: 0.6226307702577438 +[08:10:49.486] iteration 53000 [3.98 sec]: learning rate : 0.000125 loss : 0.332569 +[08:11:37.047] iteration 53100 [51.54 sec]: learning rate : 0.000125 loss : 0.303397 +[08:12:24.706] iteration 53200 [99.20 sec]: learning rate : 0.000125 loss : 0.285822 +[08:13:12.314] iteration 53300 [146.81 sec]: learning rate : 0.000125 loss : 0.351266 +[08:14:00.050] iteration 53400 [194.54 sec]: learning rate : 0.000125 loss : 0.312587 +[08:14:47.636] iteration 53500 [242.13 sec]: learning rate : 0.000125 loss : 0.349276 +[08:15:20.002] Epoch 92 Evaluation: +[08:18:15.473] average MSE: 0.08608857256803151 average PSNR: 23.665336062231535 average SSIM: 0.6166021626370024 +[08:18:30.854] iteration 53600 [15.36 sec]: learning rate : 0.000125 loss : 0.288139 +[08:19:18.421] iteration 53700 [62.93 sec]: learning rate : 0.000125 loss : 
0.326977 +[08:20:05.944] iteration 53800 [110.45 sec]: learning rate : 0.000125 loss : 0.318011 +[08:20:53.515] iteration 53900 [158.02 sec]: learning rate : 0.000125 loss : 0.296250 +[08:21:41.659] iteration 54000 [206.16 sec]: learning rate : 0.000125 loss : 0.325383 +[08:22:29.587] iteration 54100 [254.09 sec]: learning rate : 0.000125 loss : 0.323674 +[08:22:50.596] Epoch 93 Evaluation: +[08:25:41.728] average MSE: 0.05838456751877833 average PSNR: 25.370093414029093 average SSIM: 0.6518848405556935 +[08:26:08.453] iteration 54200 [26.70 sec]: learning rate : 0.000125 loss : 0.424835 +[08:26:57.004] iteration 54300 [75.25 sec]: learning rate : 0.000125 loss : 0.327455 +[08:27:44.416] iteration 54400 [122.66 sec]: learning rate : 0.000125 loss : 0.276883 +[08:28:32.439] iteration 54500 [170.69 sec]: learning rate : 0.000125 loss : 0.357387 +[08:29:20.170] iteration 54600 [218.42 sec]: learning rate : 0.000125 loss : 0.303264 +[08:30:07.738] iteration 54700 [265.99 sec]: learning rate : 0.000125 loss : 0.316813 +[08:30:17.260] Epoch 94 Evaluation: +[08:33:11.922] average MSE: 0.08439695791908157 average PSNR: 23.75394840551904 average SSIM: 0.6121431321110187 +[08:33:50.254] iteration 54800 [38.31 sec]: learning rate : 0.000125 loss : 0.345904 +[08:34:38.376] iteration 54900 [86.43 sec]: learning rate : 0.000125 loss : 0.304307 +[08:35:26.219] iteration 55000 [134.27 sec]: learning rate : 0.000125 loss : 0.372394 +[08:36:13.729] iteration 55100 [181.78 sec]: learning rate : 0.000125 loss : 0.257582 +[08:37:01.202] iteration 55200 [229.26 sec]: learning rate : 0.000125 loss : 0.212756 +[08:37:46.923] Epoch 95 Evaluation: +[08:40:41.101] average MSE: 0.08590300497992512 average PSNR: 23.67308991751771 average SSIM: 0.6185540605220141 +[08:40:43.206] iteration 55300 [2.08 sec]: learning rate : 0.000125 loss : 0.291762 +[08:41:30.797] iteration 55400 [49.67 sec]: learning rate : 0.000125 loss : 0.295744 +[08:42:18.295] iteration 55500 [97.17 sec]: learning rate : 0.000125 loss : 0.315875 +[08:43:05.930] iteration 55600 [144.81 sec]: learning rate : 0.000125 loss : 0.387522 +[08:43:53.884] iteration 55700 [192.76 sec]: learning rate : 0.000125 loss : 0.317970 +[08:44:41.548] iteration 55800 [240.42 sec]: learning rate : 0.000125 loss : 0.223397 +[08:45:15.811] Epoch 96 Evaluation: +[08:48:08.473] average MSE: 0.09807114384532434 average PSNR: 23.08734234452139 average SSIM: 0.610156953056075 +[08:48:21.970] iteration 55900 [13.47 sec]: learning rate : 0.000125 loss : 0.270533 +[08:49:09.378] iteration 56000 [60.88 sec]: learning rate : 0.000125 loss : 0.340721 +[08:49:57.385] iteration 56100 [108.89 sec]: learning rate : 0.000125 loss : 0.292838 +[08:50:44.904] iteration 56200 [156.40 sec]: learning rate : 0.000125 loss : 0.337794 +[08:51:32.363] iteration 56300 [203.86 sec]: learning rate : 0.000125 loss : 0.302097 +[08:52:19.903] iteration 56400 [251.40 sec]: learning rate : 0.000125 loss : 0.280625 +[08:52:43.025] Epoch 97 Evaluation: +[08:55:35.972] average MSE: 0.09937429393040308 average PSNR: 23.029533344411167 average SSIM: 0.6127660746713997 +[08:56:01.067] iteration 56500 [25.07 sec]: learning rate : 0.000125 loss : 0.282158 +[08:56:48.632] iteration 56600 [72.64 sec]: learning rate : 0.000125 loss : 0.353534 +[08:57:36.280] iteration 56700 [120.29 sec]: learning rate : 0.000125 loss : 0.342118 +[08:58:23.929] iteration 56800 [167.93 sec]: learning rate : 0.000125 loss : 4.327261 +[08:59:11.954] iteration 56900 [215.96 sec]: learning rate : 0.000125 loss : 0.350034 +[08:59:59.607] 
iteration 57000 [263.61 sec]: learning rate : 0.000125 loss : 0.286065 +[09:00:11.018] Epoch 98 Evaluation: +[09:03:04.748] average MSE: 0.11052964017831439 average PSNR: 22.565985995978284 average SSIM: 0.6027809996414094 +[09:03:41.217] iteration 57100 [36.45 sec]: learning rate : 0.000125 loss : 0.282907 +[09:04:28.765] iteration 57200 [84.00 sec]: learning rate : 0.000125 loss : 0.319556 +[09:05:17.018] iteration 57300 [132.25 sec]: learning rate : 0.000125 loss : 0.371235 +[09:06:04.964] iteration 57400 [180.19 sec]: learning rate : 0.000125 loss : 0.297569 +[09:06:52.658] iteration 57500 [227.89 sec]: learning rate : 0.000125 loss : 0.329394 +[09:07:40.258] iteration 57600 [275.49 sec]: learning rate : 0.000125 loss : 1.688811 +[09:07:40.293] Epoch 99 Evaluation: +[09:10:37.112] average MSE: 0.08696376343547756 average PSNR: 23.616811314729098 average SSIM: 0.6129533318038273 +[09:11:25.549] iteration 57700 [48.41 sec]: learning rate : 0.000125 loss : 1.194367 +[09:12:13.537] iteration 57800 [96.40 sec]: learning rate : 0.000125 loss : 0.319365 +[09:13:01.161] iteration 57900 [144.03 sec]: learning rate : 0.000125 loss : 0.327277 +[09:13:48.843] iteration 58000 [191.71 sec]: learning rate : 0.000125 loss : 0.265132 +[09:14:36.671] iteration 58100 [239.54 sec]: learning rate : 0.000125 loss : 0.364323 +[09:15:12.805] Epoch 100 Evaluation: +[09:18:09.465] average MSE: 0.07544259678613785 average PSNR: 24.238559006003253 average SSIM: 0.6198320265675282 +[09:18:21.278] iteration 58200 [11.79 sec]: learning rate : 0.000125 loss : 0.321396 +[09:19:08.680] iteration 58300 [59.19 sec]: learning rate : 0.000125 loss : 0.278239 +[09:19:56.601] iteration 58400 [107.12 sec]: learning rate : 0.000125 loss : 0.380124 +[09:20:44.277] iteration 58500 [154.79 sec]: learning rate : 0.000125 loss : 0.310830 +[09:21:31.847] iteration 58600 [202.36 sec]: learning rate : 0.000125 loss : 0.253148 +[09:22:19.586] iteration 58700 [250.10 sec]: learning rate : 0.000125 loss : 0.354126 +[09:22:44.531] Epoch 101 Evaluation: +[09:25:46.983] average MSE: 0.09497978862008996 average PSNR: 23.233250915945085 average SSIM: 0.6029976009397463 +[09:26:10.042] iteration 58800 [23.04 sec]: learning rate : 0.000125 loss : 0.358283 +[09:26:57.677] iteration 58900 [70.67 sec]: learning rate : 0.000125 loss : 0.243855 +[09:27:45.260] iteration 59000 [118.25 sec]: learning rate : 0.000125 loss : 0.302282 +[09:28:33.431] iteration 59100 [166.42 sec]: learning rate : 0.000125 loss : 0.283549 +[09:29:21.941] iteration 59200 [214.93 sec]: learning rate : 0.000125 loss : 0.396509 +[09:30:09.500] iteration 59300 [262.49 sec]: learning rate : 0.000125 loss : 0.298477 +[09:30:22.792] Epoch 102 Evaluation: +[09:33:16.528] average MSE: 0.060627981442889556 average PSNR: 25.208240486584245 average SSIM: 0.6458002516559391 +[09:33:50.935] iteration 59400 [34.38 sec]: learning rate : 0.000125 loss : 0.311858 +[09:34:39.139] iteration 59500 [82.59 sec]: learning rate : 0.000125 loss : 0.362947 +[09:35:27.005] iteration 59600 [130.45 sec]: learning rate : 0.000125 loss : 0.322158 +[09:36:14.630] iteration 59700 [178.08 sec]: learning rate : 0.000125 loss : 0.277777 +[09:37:02.317] iteration 59800 [225.77 sec]: learning rate : 0.000125 loss : 0.256246 +[09:37:50.221] iteration 59900 [273.67 sec]: learning rate : 0.000125 loss : 0.345450 +[09:37:52.135] Epoch 103 Evaluation: +[09:40:43.888] average MSE: 0.09625975836085549 average PSNR: 23.171118950265896 average SSIM: 0.6088503189996699 +[09:41:30.267] iteration 60000 [46.35 sec]: learning 
rate : 0.000031 loss : 0.361567 +[09:41:30.431] save model to model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/iter_60000.pth +[09:42:18.140] iteration 60100 [94.23 sec]: learning rate : 0.000063 loss : 0.257053 +[09:43:05.752] iteration 60200 [141.84 sec]: learning rate : 0.000063 loss : 0.349397 +[09:43:53.534] iteration 60300 [189.62 sec]: learning rate : 0.000063 loss : 0.304786 +[09:44:41.303] iteration 60400 [237.39 sec]: learning rate : 0.000063 loss : 0.279973 +[09:45:19.875] Epoch 104 Evaluation: +[09:48:17.236] average MSE: 0.10681284349129933 average PSNR: 22.715076837031358 average SSIM: 0.6062972903701932 +[09:48:26.912] iteration 60500 [9.65 sec]: learning rate : 0.000063 loss : 0.259085 +[09:49:14.475] iteration 60600 [57.22 sec]: learning rate : 0.000063 loss : 0.306983 +[09:50:01.989] iteration 60700 [104.75 sec]: learning rate : 0.000063 loss : 0.351854 +[09:50:49.782] iteration 60800 [152.52 sec]: learning rate : 0.000063 loss : 0.262164 +[09:51:37.329] iteration 60900 [200.07 sec]: learning rate : 0.000063 loss : 0.269726 +[09:52:24.788] iteration 61000 [247.53 sec]: learning rate : 0.000063 loss : 0.317139 +[09:52:51.447] Epoch 105 Evaluation: +[09:55:43.471] average MSE: 0.09886353071960734 average PSNR: 23.05507059788808 average SSIM: 0.6016795694369919 +[09:56:04.516] iteration 61100 [21.02 sec]: learning rate : 0.000063 loss : 0.323412 +[09:56:52.940] iteration 61200 [69.45 sec]: learning rate : 0.000063 loss : 0.289036 +[09:57:40.557] iteration 61300 [117.06 sec]: learning rate : 0.000063 loss : 0.335007 +[09:58:28.325] iteration 61400 [164.83 sec]: learning rate : 0.000063 loss : 0.307172 +[09:59:16.524] iteration 61500 [213.03 sec]: learning rate : 0.000063 loss : 0.296399 +[10:00:03.982] iteration 61600 [260.49 sec]: learning rate : 0.000063 loss : 0.336805 +[10:00:19.164] Epoch 106 Evaluation: +[10:03:11.772] average MSE: 0.12247087436903452 average PSNR: 22.12345288220932 average SSIM: 0.6094734886712776 +[10:03:44.406] iteration 61700 [32.61 sec]: learning rate : 0.000063 loss : 0.288414 +[10:04:32.476] iteration 61800 [80.68 sec]: learning rate : 0.000063 loss : 0.279915 +[10:05:20.508] iteration 61900 [128.71 sec]: learning rate : 0.000063 loss : 0.305920 +[10:06:08.255] iteration 62000 [176.46 sec]: learning rate : 0.000063 loss : 0.281201 +[10:06:55.926] iteration 62100 [224.13 sec]: learning rate : 0.000063 loss : 0.389238 +[10:07:43.997] iteration 62200 [272.20 sec]: learning rate : 0.000063 loss : 0.341515 +[10:07:47.813] Epoch 107 Evaluation: +[10:10:49.512] average MSE: 0.1016942605616629 average PSNR: 22.929813110711056 average SSIM: 0.604102589793084 +[10:11:34.017] iteration 62300 [44.48 sec]: learning rate : 0.000063 loss : 0.317858 +[10:12:21.400] iteration 62400 [91.86 sec]: learning rate : 0.000063 loss : 0.332700 +[10:13:09.458] iteration 62500 [139.92 sec]: learning rate : 0.000063 loss : 0.266143 +[10:13:57.544] iteration 62600 [188.01 sec]: learning rate : 0.000063 loss : 0.330043 +[10:14:45.048] iteration 62700 [235.51 sec]: learning rate : 0.000063 loss : 0.288348 +[10:15:24.981] Epoch 108 Evaluation: +[10:18:18.919] average MSE: 0.11069739062415883 average PSNR: 22.559716152296946 average SSIM: 0.6116731128259099 +[10:18:26.755] iteration 62800 [7.81 sec]: learning rate : 0.000063 loss : 0.358893 +[10:19:15.175] iteration 62900 [56.23 sec]: learning rate : 0.000063 loss : 0.293786 +[10:20:02.802] iteration 63000 [103.86 sec]: learning rate : 0.000063 loss : 0.316201 +[10:20:50.502] iteration 63100 [151.56 sec]: learning rate : 0.000063 
loss : 0.227755 +[10:21:38.122] iteration 63200 [199.18 sec]: learning rate : 0.000063 loss : 0.374065 +[10:22:26.100] iteration 63300 [247.16 sec]: learning rate : 0.000063 loss : 0.263402 +[10:22:54.572] Epoch 109 Evaluation: +[10:25:54.466] average MSE: 0.08834525734109365 average PSNR: 23.54725661377232 average SSIM: 0.6140166375374934 +[10:26:13.893] iteration 63400 [19.41 sec]: learning rate : 0.000063 loss : 0.337203 +[10:27:01.482] iteration 63500 [66.99 sec]: learning rate : 0.000063 loss : 0.356029 +[10:27:49.159] iteration 63600 [114.67 sec]: learning rate : 0.000063 loss : 0.307767 +[10:28:36.832] iteration 63700 [162.34 sec]: learning rate : 0.000063 loss : 0.355306 +[10:29:25.060] iteration 63800 [210.57 sec]: learning rate : 0.000063 loss : 0.282091 +[10:30:13.110] iteration 63900 [258.62 sec]: learning rate : 0.000063 loss : 0.415764 +[10:30:30.258] Epoch 110 Evaluation: +[10:33:23.304] average MSE: 0.09445954400238514 average PSNR: 23.251752030271017 average SSIM: 0.6105287603195029 +[10:33:54.107] iteration 64000 [30.78 sec]: learning rate : 0.000063 loss : 0.409021 +[10:34:41.632] iteration 64100 [78.31 sec]: learning rate : 0.000063 loss : 0.366679 +[10:35:30.395] iteration 64200 [127.07 sec]: learning rate : 0.000063 loss : 0.325144 +[10:36:17.925] iteration 64300 [174.60 sec]: learning rate : 0.000063 loss : 0.305743 +[10:37:05.688] iteration 64400 [222.36 sec]: learning rate : 0.000063 loss : 0.377158 +[10:37:53.453] iteration 64500 [270.13 sec]: learning rate : 0.000063 loss : 0.320814 +[10:37:59.148] Epoch 111 Evaluation: +[10:40:52.572] average MSE: 0.09969784551933909 average PSNR: 23.016883387050516 average SSIM: 0.6054063374234955 +[10:41:35.197] iteration 64600 [42.60 sec]: learning rate : 0.000063 loss : 0.349743 +[10:42:22.809] iteration 64700 [90.21 sec]: learning rate : 0.000063 loss : 0.317016 +[10:43:10.335] iteration 64800 [137.74 sec]: learning rate : 0.000063 loss : 0.309142 +[10:43:57.807] iteration 64900 [185.21 sec]: learning rate : 0.000063 loss : 0.310573 +[10:44:45.427] iteration 65000 [232.83 sec]: learning rate : 0.000063 loss : 0.323788 +[10:45:27.168] Epoch 112 Evaluation: +[10:48:25.196] average MSE: 0.10177132718360424 average PSNR: 22.926163672722744 average SSIM: 0.6089950636281152 +[10:48:31.079] iteration 65100 [5.86 sec]: learning rate : 0.000063 loss : 0.354456 +[10:49:19.146] iteration 65200 [53.93 sec]: learning rate : 0.000063 loss : 0.347455 +[10:50:06.681] iteration 65300 [101.46 sec]: learning rate : 0.000063 loss : 0.240044 +[10:50:54.096] iteration 65400 [148.88 sec]: learning rate : 0.000063 loss : 0.422378 +[10:51:41.670] iteration 65500 [196.45 sec]: learning rate : 0.000063 loss : 0.311198 +[10:52:29.697] iteration 65600 [244.48 sec]: learning rate : 0.000063 loss : 0.420354 +[10:53:00.060] Epoch 113 Evaluation: +[10:55:52.448] average MSE: 0.11907467756674617 average PSNR: 22.246403963572952 average SSIM: 0.6190447262832758 +[10:56:09.788] iteration 65700 [17.32 sec]: learning rate : 0.000063 loss : 0.309417 +[10:56:57.523] iteration 65800 [65.05 sec]: learning rate : 0.000063 loss : 0.306069 +[10:57:45.744] iteration 65900 [113.27 sec]: learning rate : 0.000063 loss : 0.322079 +[10:58:33.174] iteration 66000 [160.70 sec]: learning rate : 0.000063 loss : 0.356669 +[10:59:20.766] iteration 66100 [208.30 sec]: learning rate : 0.000063 loss : 0.407172 +[11:00:08.508] iteration 66200 [256.04 sec]: learning rate : 0.000063 loss : 0.302311 +[11:00:27.651] Epoch 114 Evaluation: +[11:03:22.871] average MSE: 0.12870833454779704 
average PSNR: 21.91566619396799 average SSIM: 0.6161016523258641 +[11:03:51.577] iteration 66300 [28.68 sec]: learning rate : 0.000063 loss : 0.412821 +[11:04:39.389] iteration 66400 [76.50 sec]: learning rate : 0.000063 loss : 0.286150 +[11:05:26.940] iteration 66500 [124.05 sec]: learning rate : 0.000063 loss : 0.414422 +[11:06:14.874] iteration 66600 [171.98 sec]: learning rate : 0.000063 loss : 1.825459 +[11:07:02.572] iteration 66700 [219.68 sec]: learning rate : 0.000063 loss : 0.340342 +[11:07:50.070] iteration 66800 [267.18 sec]: learning rate : 0.000063 loss : 0.283292 +[11:07:57.696] Epoch 115 Evaluation: +[11:10:51.916] average MSE: 0.10886272688108664 average PSNR: 22.63274891087109 average SSIM: 0.610730732285194 +[11:11:32.151] iteration 66900 [40.21 sec]: learning rate : 0.000063 loss : 0.325049 +[11:12:20.125] iteration 67000 [88.19 sec]: learning rate : 0.000063 loss : 0.411314 +[11:13:07.650] iteration 67100 [135.71 sec]: learning rate : 0.000063 loss : 0.318720 +[11:13:55.241] iteration 67200 [183.30 sec]: learning rate : 0.000063 loss : 0.273800 +[11:14:43.111] iteration 67300 [231.17 sec]: learning rate : 0.000063 loss : 0.373322 +[11:15:26.898] Epoch 116 Evaluation: +[11:18:19.686] average MSE: 0.1249494674421785 average PSNR: 22.04154457513515 average SSIM: 0.616077138474014 +[11:18:23.680] iteration 67400 [3.97 sec]: learning rate : 0.000063 loss : 0.324938 +[11:19:11.295] iteration 67500 [51.59 sec]: learning rate : 0.000063 loss : 0.249638 +[11:19:58.748] iteration 67600 [99.04 sec]: learning rate : 0.000063 loss : 3.234504 +[11:20:46.752] iteration 67700 [147.04 sec]: learning rate : 0.000063 loss : 0.353459 +[11:21:34.326] iteration 67800 [194.62 sec]: learning rate : 0.000063 loss : 0.269249 +[11:22:21.806] iteration 67900 [242.10 sec]: learning rate : 0.000063 loss : 0.359270 +[11:22:54.199] Epoch 117 Evaluation: +[11:25:57.061] average MSE: 0.12089254655030306 average PSNR: 22.180113437185902 average SSIM: 0.6146454311705843 +[11:26:12.439] iteration 68000 [15.36 sec]: learning rate : 0.000063 loss : 0.306661 +[11:27:00.107] iteration 68100 [63.02 sec]: learning rate : 0.000063 loss : 0.344765 +[11:27:47.639] iteration 68200 [110.56 sec]: learning rate : 0.000063 loss : 0.399376 +[11:28:35.163] iteration 68300 [158.08 sec]: learning rate : 0.000063 loss : 0.294110 +[11:29:22.649] iteration 68400 [205.56 sec]: learning rate : 0.000063 loss : 0.282846 +[11:30:10.722] iteration 68500 [253.64 sec]: learning rate : 0.000063 loss : 0.343610 +[11:30:31.672] Epoch 118 Evaluation: +[11:33:30.083] average MSE: 0.10620344147245805 average PSNR: 22.73813218016753 average SSIM: 0.6115104583095274 +[11:33:57.047] iteration 68600 [26.94 sec]: learning rate : 0.000063 loss : 0.373108 +[11:34:44.500] iteration 68700 [74.39 sec]: learning rate : 0.000063 loss : 0.311419 +[11:35:32.102] iteration 68800 [122.00 sec]: learning rate : 0.000063 loss : 0.293180 +[11:36:20.219] iteration 68900 [170.11 sec]: learning rate : 0.000063 loss : 0.408493 +[11:37:08.036] iteration 69000 [217.93 sec]: learning rate : 0.000063 loss : 0.232749 +[11:37:55.596] iteration 69100 [265.49 sec]: learning rate : 0.000063 loss : 0.305517 +[11:38:05.096] Epoch 119 Evaluation: +[11:40:59.338] average MSE: 0.11845530102589426 average PSNR: 22.266335360305835 average SSIM: 0.6191562091866065 +[11:41:37.470] iteration 69200 [38.11 sec]: learning rate : 0.000063 loss : 0.428531 +[11:42:26.573] iteration 69300 [87.21 sec]: learning rate : 0.000063 loss : 0.334771 +[11:43:14.181] iteration 69400 [134.82 sec]: 
learning rate : 0.000063 loss : 0.320262 +[11:44:01.735] iteration 69500 [182.38 sec]: learning rate : 0.000063 loss : 0.220471 +[11:44:49.326] iteration 69600 [229.97 sec]: learning rate : 0.000063 loss : 0.241015 +[11:45:34.979] Epoch 120 Evaluation: +[11:48:27.820] average MSE: 0.12982545743257107 average PSNR: 21.87428071020307 average SSIM: 0.6143782556508071 +[11:48:29.922] iteration 69700 [2.08 sec]: learning rate : 0.000063 loss : 0.565587 +[11:49:17.318] iteration 69800 [49.47 sec]: learning rate : 0.000063 loss : 0.306525 +[11:50:04.886] iteration 69900 [97.04 sec]: learning rate : 0.000063 loss : 0.287040 +[11:50:52.395] iteration 70000 [144.55 sec]: learning rate : 0.000063 loss : 0.383737 +[11:51:40.200] iteration 70100 [192.36 sec]: learning rate : 0.000063 loss : 0.306720 +[11:52:27.751] iteration 70200 [239.91 sec]: learning rate : 0.000063 loss : 0.196406 +[11:53:01.891] Epoch 121 Evaluation: +[11:55:53.980] average MSE: 0.09737788480483878 average PSNR: 23.118873074321392 average SSIM: 0.6205190427734749 +[11:56:07.626] iteration 70300 [13.62 sec]: learning rate : 0.000063 loss : 0.298854 +[11:56:55.207] iteration 70400 [61.20 sec]: learning rate : 0.000063 loss : 0.316866 +[11:57:42.890] iteration 70500 [108.89 sec]: learning rate : 0.000063 loss : 0.336514 +[11:58:30.457] iteration 70600 [156.45 sec]: learning rate : 0.000063 loss : 0.376976 +[11:59:18.699] iteration 70700 [204.70 sec]: learning rate : 0.000063 loss : 0.283138 +[12:00:06.966] iteration 70800 [252.96 sec]: learning rate : 0.000063 loss : 0.322984 +[12:00:30.054] Epoch 122 Evaluation: +[12:03:26.572] average MSE: 0.12082128493691023 average PSNR: 22.181690694200753 average SSIM: 0.6168391031702314 +[12:03:51.403] iteration 70900 [24.81 sec]: learning rate : 0.000063 loss : 0.261402 +[12:04:39.504] iteration 71000 [72.93 sec]: learning rate : 0.000063 loss : 0.314979 +[12:05:27.079] iteration 71100 [120.48 sec]: learning rate : 0.000063 loss : 0.343453 +[12:06:14.718] iteration 71200 [168.12 sec]: learning rate : 0.000063 loss : 0.270551 +[12:07:02.356] iteration 71300 [215.76 sec]: learning rate : 0.000063 loss : 0.322171 +[12:07:50.001] iteration 71400 [263.41 sec]: learning rate : 0.000063 loss : 0.269777 +[12:08:01.384] Epoch 123 Evaluation: +[12:11:00.424] average MSE: 0.11240719841417482 average PSNR: 22.493393333167557 average SSIM: 0.611336411469287 +[12:11:36.648] iteration 71500 [36.20 sec]: learning rate : 0.000063 loss : 0.292927 +[12:12:24.484] iteration 71600 [84.04 sec]: learning rate : 0.000063 loss : 0.313804 +[12:13:12.051] iteration 71700 [131.60 sec]: learning rate : 0.000063 loss : 0.352796 +[12:13:59.459] iteration 71800 [179.01 sec]: learning rate : 0.000063 loss : 0.310313 +[12:14:47.074] iteration 71900 [226.63 sec]: learning rate : 0.000063 loss : 0.372332 +[12:15:34.496] iteration 72000 [274.05 sec]: learning rate : 0.000063 loss : 0.354586 +[12:15:34.532] Epoch 124 Evaluation: +[12:18:27.623] average MSE: 0.09000988547040023 average PSNR: 23.46150128210785 average SSIM: 0.6168020912808289 +[12:19:15.319] iteration 72100 [47.67 sec]: learning rate : 0.000063 loss : 0.263327 +[12:20:02.760] iteration 72200 [95.11 sec]: learning rate : 0.000063 loss : 0.301189 +[12:20:50.138] iteration 72300 [142.49 sec]: learning rate : 0.000063 loss : 0.330635 +[12:21:38.129] iteration 72400 [190.48 sec]: learning rate : 0.000063 loss : 0.272265 +[12:22:25.905] iteration 72500 [238.26 sec]: learning rate : 0.000063 loss : 0.362063 +[12:23:02.141] Epoch 125 Evaluation: +[12:25:59.811] average MSE: 
0.10484192033881388 average PSNR: 22.79529767944722 average SSIM: 0.6151785745700953 +[12:26:11.403] iteration 72600 [11.57 sec]: learning rate : 0.000063 loss : 0.388979 +[12:26:59.328] iteration 72700 [59.49 sec]: learning rate : 0.000063 loss : 0.287190 +[12:27:47.168] iteration 72800 [107.35 sec]: learning rate : 0.000063 loss : 0.345195 +[12:28:34.881] iteration 72900 [155.05 sec]: learning rate : 0.000063 loss : 0.317611 +[12:29:22.641] iteration 73000 [202.81 sec]: learning rate : 0.000063 loss : 0.266804 +[12:30:10.428] iteration 73100 [250.61 sec]: learning rate : 0.000063 loss : 0.350636 +[12:30:35.206] Epoch 126 Evaluation: +[12:33:26.885] average MSE: 0.11336384435103536 average PSNR: 22.454653160916777 average SSIM: 0.6074326053895069 +[12:33:49.839] iteration 73200 [22.93 sec]: learning rate : 0.000063 loss : 0.387726 +[12:34:37.401] iteration 73300 [70.49 sec]: learning rate : 0.000063 loss : 0.265283 +[12:35:24.940] iteration 73400 [118.03 sec]: learning rate : 0.000063 loss : 0.306155 +[12:36:13.396] iteration 73500 [166.49 sec]: learning rate : 0.000063 loss : 0.332787 +[12:37:00.928] iteration 73600 [214.02 sec]: learning rate : 0.000063 loss : 0.374188 +[12:37:48.468] iteration 73700 [261.56 sec]: learning rate : 0.000063 loss : 0.330067 +[12:38:01.803] Epoch 127 Evaluation: +[12:40:52.993] average MSE: 0.12930653094009703 average PSNR: 21.893315579783973 average SSIM: 0.6201879099019688 +[12:41:27.654] iteration 73800 [34.64 sec]: learning rate : 0.000063 loss : 0.330052 +[12:42:15.830] iteration 73900 [82.81 sec]: learning rate : 0.000063 loss : 0.390841 +[12:43:03.294] iteration 74000 [130.28 sec]: learning rate : 0.000063 loss : 0.373412 +[12:43:51.358] iteration 74100 [178.34 sec]: learning rate : 0.000063 loss : 0.325681 +[12:44:38.949] iteration 74200 [225.93 sec]: learning rate : 0.000063 loss : 0.231887 +[12:45:26.966] iteration 74300 [273.95 sec]: learning rate : 0.000063 loss : 0.337275 +[12:45:28.874] Epoch 128 Evaluation: +[12:48:22.807] average MSE: 0.11363266688110495 average PSNR: 22.446413862884807 average SSIM: 0.6114981570753778 +[12:49:08.741] iteration 74400 [45.91 sec]: learning rate : 0.000063 loss : 0.351757 +[12:49:56.144] iteration 74500 [93.31 sec]: learning rate : 0.000063 loss : 0.267737 +[12:50:43.753] iteration 74600 [140.92 sec]: learning rate : 0.000063 loss : 0.349348 +[12:51:31.392] iteration 74700 [188.56 sec]: learning rate : 0.000063 loss : 0.309382 +[12:52:18.920] iteration 74800 [236.09 sec]: learning rate : 0.000063 loss : 0.301728 +[12:52:56.930] Epoch 129 Evaluation: +[12:55:48.417] average MSE: 0.11135096035798163 average PSNR: 22.53367326508478 average SSIM: 0.6144920261781043 +[12:55:58.150] iteration 74900 [9.71 sec]: learning rate : 0.000063 loss : 0.232573 +[12:56:45.800] iteration 75000 [57.36 sec]: learning rate : 0.000063 loss : 0.339513 +[12:57:33.204] iteration 75100 [104.76 sec]: learning rate : 0.000063 loss : 0.345874 +[12:58:20.681] iteration 75200 [152.24 sec]: learning rate : 0.000063 loss : 0.279290 +[12:59:08.322] iteration 75300 [199.88 sec]: learning rate : 0.000063 loss : 0.275034 +[12:59:55.861] iteration 75400 [247.42 sec]: learning rate : 0.000063 loss : 0.326637 +[13:00:23.038] Epoch 130 Evaluation: +[13:03:16.697] average MSE: 0.13299823107160838 average PSNR: 21.77517620452468 average SSIM: 0.6163402125686818 +[13:03:37.764] iteration 75500 [21.04 sec]: learning rate : 0.000063 loss : 0.307095 +[13:04:25.310] iteration 75600 [68.59 sec]: learning rate : 0.000063 loss : 0.281475 +[13:05:12.694] 
iteration 75700 [115.97 sec]: learning rate : 0.000063 loss : 0.349183 +[13:06:00.142] iteration 75800 [163.42 sec]: learning rate : 0.000063 loss : 0.341997 +[13:06:47.892] iteration 75900 [211.17 sec]: learning rate : 0.000063 loss : 0.282275 +[13:07:35.526] iteration 76000 [258.80 sec]: learning rate : 0.000063 loss : 0.295753 +[13:07:50.741] Epoch 131 Evaluation: +[13:10:48.363] average MSE: 0.11304304447737416 average PSNR: 22.469149030758622 average SSIM: 0.6142647495601448 +[13:11:20.905] iteration 76100 [32.52 sec]: learning rate : 0.000063 loss : 0.293848 +[13:12:08.284] iteration 76200 [79.92 sec]: learning rate : 0.000063 loss : 0.321778 +[13:12:56.360] iteration 76300 [127.97 sec]: learning rate : 0.000063 loss : 0.325856 +[13:13:43.803] iteration 76400 [175.42 sec]: learning rate : 0.000063 loss : 0.306969 +[13:14:31.291] iteration 76500 [222.90 sec]: learning rate : 0.000063 loss : 0.381406 +[13:15:18.769] iteration 76600 [270.38 sec]: learning rate : 0.000063 loss : 0.317523 +[13:15:22.564] Epoch 132 Evaluation: +[13:18:13.994] average MSE: 0.12347725343029414 average PSNR: 22.0884644001956 average SSIM: 0.6180866962480244 +[13:18:58.379] iteration 76700 [44.36 sec]: learning rate : 0.000063 loss : 0.236152 +[13:19:45.718] iteration 76800 [91.70 sec]: learning rate : 0.000063 loss : 0.381493 +[13:20:33.172] iteration 76900 [139.15 sec]: learning rate : 0.000063 loss : 0.293525 +[13:21:20.764] iteration 77000 [186.75 sec]: learning rate : 0.000063 loss : 0.313423 +[13:22:08.294] iteration 77100 [234.28 sec]: learning rate : 0.000063 loss : 0.274342 +[13:22:48.167] Epoch 133 Evaluation: +[13:25:39.437] average MSE: 0.1248370208125125 average PSNR: 22.04790370543186 average SSIM: 0.6246137511851071 +[13:25:47.207] iteration 77200 [7.75 sec]: learning rate : 0.000063 loss : 0.316305 +[13:26:34.560] iteration 77300 [55.10 sec]: learning rate : 0.000063 loss : 0.329570 +[13:27:22.078] iteration 77400 [102.62 sec]: learning rate : 0.000063 loss : 0.327109 +[13:28:09.542] iteration 77500 [150.08 sec]: learning rate : 0.000063 loss : 0.232287 +[13:28:56.946] iteration 77600 [197.49 sec]: learning rate : 0.000063 loss : 0.356386 +[13:29:44.448] iteration 77700 [244.99 sec]: learning rate : 0.000063 loss : 0.281697 +[13:30:13.090] Epoch 134 Evaluation: +[13:33:04.872] average MSE: 0.1392581943598641 average PSNR: 21.584772400215204 average SSIM: 0.6173155054069004 +[13:33:24.213] iteration 77800 [19.32 sec]: learning rate : 0.000063 loss : 0.295538 +[13:34:11.680] iteration 77900 [66.78 sec]: learning rate : 0.000063 loss : 0.320489 +[13:34:59.266] iteration 78000 [114.37 sec]: learning rate : 0.000063 loss : 0.318057 +[13:35:46.824] iteration 78100 [161.93 sec]: learning rate : 0.000063 loss : 0.360866 +[13:36:34.805] iteration 78200 [209.93 sec]: learning rate : 0.000063 loss : 0.281904 +[13:37:22.752] iteration 78300 [257.86 sec]: learning rate : 0.000063 loss : 0.382951 +[13:37:39.885] Epoch 135 Evaluation: +[13:40:34.751] average MSE: 0.11838152679559308 average PSNR: 22.27108737710377 average SSIM: 0.61571836805463 +[13:41:05.339] iteration 78400 [30.56 sec]: learning rate : 0.000063 loss : 0.366140 +[13:41:52.945] iteration 78500 [78.17 sec]: learning rate : 0.000063 loss : 0.336050 +[13:42:40.788] iteration 78600 [126.01 sec]: learning rate : 0.000063 loss : 0.366177 +[13:43:28.285] iteration 78700 [173.51 sec]: learning rate : 0.000063 loss : 0.329068 +[13:44:15.967] iteration 78800 [221.19 sec]: learning rate : 0.000063 loss : 0.349778 +[13:45:03.552] iteration 78900 [268.78 
sec]: learning rate : 0.000063 loss : 0.332952 +[13:45:09.253] Epoch 136 Evaluation: +[13:48:03.391] average MSE: 0.12291192183477195 average PSNR: 22.109760770851825 average SSIM: 0.6167942912919719 +[13:48:45.659] iteration 79000 [42.25 sec]: learning rate : 0.000063 loss : 0.320529 +[13:49:33.322] iteration 79100 [89.91 sec]: learning rate : 0.000063 loss : 0.328803 +[13:50:20.960] iteration 79200 [137.55 sec]: learning rate : 0.000063 loss : 0.310951 +[13:51:08.516] iteration 79300 [185.10 sec]: learning rate : 0.000063 loss : 0.313549 +[13:51:56.224] iteration 79400 [232.81 sec]: learning rate : 0.000063 loss : 0.313797 +[13:52:38.068] Epoch 137 Evaluation: +[13:55:37.027] average MSE: 0.12726149239719906 average PSNR: 21.961597249191495 average SSIM: 0.615264753362259 +[13:55:42.924] iteration 79500 [5.87 sec]: learning rate : 0.000063 loss : 0.334924 +[13:56:30.444] iteration 79600 [53.39 sec]: learning rate : 0.000063 loss : 0.327406 +[13:57:18.356] iteration 79700 [101.31 sec]: learning rate : 0.000063 loss : 0.299317 +[13:58:05.733] iteration 79800 [148.68 sec]: learning rate : 0.000063 loss : 0.388615 +[13:58:53.265] iteration 79900 [196.22 sec]: learning rate : 0.000063 loss : 0.292090 +[13:59:40.793] iteration 80000 [243.74 sec]: learning rate : 0.000016 loss : 0.359692 +[13:59:40.952] save model to model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/iter_80000.pth +[14:00:11.288] Epoch 138 Evaluation: +[14:03:02.871] average MSE: 0.11796441481105306 average PSNR: 22.28604167372905 average SSIM: 0.6131603821153403 +[14:03:20.117] iteration 80100 [17.22 sec]: learning rate : 0.000031 loss : 0.283178 +[14:04:07.633] iteration 80200 [64.74 sec]: learning rate : 0.000031 loss : 0.335037 +[14:04:55.113] iteration 80300 [112.22 sec]: learning rate : 0.000031 loss : 0.279679 +[14:05:43.199] iteration 80400 [160.30 sec]: learning rate : 0.000031 loss : 0.368960 +[14:06:30.791] iteration 80500 [207.90 sec]: learning rate : 0.000031 loss : 0.350660 +[14:07:18.680] iteration 80600 [255.79 sec]: learning rate : 0.000031 loss : 0.283323 +[14:07:37.899] Epoch 139 Evaluation: +[14:10:39.272] average MSE: 0.11671934569526292 average PSNR: 22.33045045163173 average SSIM: 0.6114313654115306 +[14:11:08.421] iteration 80700 [29.12 sec]: learning rate : 0.000031 loss : 0.383758 +[14:11:55.912] iteration 80800 [76.62 sec]: learning rate : 0.000031 loss : 0.277249 +[14:12:43.292] iteration 80900 [124.00 sec]: learning rate : 0.000031 loss : 0.352612 +[14:13:30.713] iteration 81000 [171.42 sec]: learning rate : 0.000031 loss : 0.369684 +[14:14:18.245] iteration 81100 [218.97 sec]: learning rate : 0.000031 loss : 0.328310 +[14:15:05.717] iteration 81200 [266.42 sec]: learning rate : 0.000031 loss : 0.299322 +[14:15:13.300] Epoch 140 Evaluation: +[14:18:03.682] average MSE: 0.12369261900966193 average PSNR: 22.081254645110832 average SSIM: 0.6121031673948393 +[14:18:43.780] iteration 81300 [40.07 sec]: learning rate : 0.000031 loss : 0.332860 +[14:19:31.187] iteration 81400 [87.48 sec]: learning rate : 0.000031 loss : 0.331923 +[14:20:18.516] iteration 81500 [134.81 sec]: learning rate : 0.000031 loss : 0.303918 +[14:21:05.967] iteration 81600 [182.26 sec]: learning rate : 0.000031 loss : 0.240413 +[14:21:53.324] iteration 81700 [229.62 sec]: learning rate : 0.000031 loss : 0.328518 +[14:22:37.063] Epoch 141 Evaluation: +[14:25:31.305] average MSE: 0.13261680620610736 average PSNR: 21.788345537623762 average SSIM: 0.6181662603548063 +[14:25:35.340] iteration 81800 [4.01 sec]: learning rate : 0.000031 loss 
: 0.339083 +[14:26:22.841] iteration 81900 [51.51 sec]: learning rate : 0.000031 loss : 0.286442 +[14:27:10.533] iteration 82000 [99.20 sec]: learning rate : 0.000031 loss : 0.317689 +[14:27:57.986] iteration 82100 [146.66 sec]: learning rate : 0.000031 loss : 0.354600 +[14:28:45.616] iteration 82200 [194.29 sec]: learning rate : 0.000031 loss : 0.271180 +[14:29:33.162] iteration 82300 [241.83 sec]: learning rate : 0.000031 loss : 0.364892 +[14:30:05.575] Epoch 142 Evaluation: +[14:32:55.329] average MSE: 0.125756145452213 average PSNR: 22.013314601387265 average SSIM: 0.6185352528025274 +[14:33:10.690] iteration 82400 [15.34 sec]: learning rate : 0.000031 loss : 0.288942 +[14:33:57.990] iteration 82500 [62.64 sec]: learning rate : 0.000031 loss : 0.321417 +[14:34:45.467] iteration 82600 [110.11 sec]: learning rate : 0.000031 loss : 0.355998 +[14:35:32.915] iteration 82700 [157.56 sec]: learning rate : 0.000031 loss : 0.307360 +[14:36:20.269] iteration 82800 [204.92 sec]: learning rate : 0.000031 loss : 0.317074 +[14:37:07.809] iteration 82900 [252.46 sec]: learning rate : 0.000031 loss : 0.271683 +[14:37:28.784] Epoch 143 Evaluation: +[14:40:27.601] average MSE: 0.12546576866578532 average PSNR: 22.024929668733183 average SSIM: 0.621619330811804 +[14:40:54.547] iteration 83000 [26.92 sec]: learning rate : 0.000031 loss : 0.381043 +[14:41:42.076] iteration 83100 [74.45 sec]: learning rate : 0.000031 loss : 0.312107 +[14:42:29.485] iteration 83200 [121.86 sec]: learning rate : 0.000031 loss : 0.275713 +[14:43:16.883] iteration 83300 [169.26 sec]: learning rate : 0.000031 loss : 0.405223 +[14:44:04.257] iteration 83400 [216.63 sec]: learning rate : 0.000031 loss : 0.313521 +[14:44:51.706] iteration 83500 [264.08 sec]: learning rate : 0.000031 loss : 0.293804 +[14:45:01.180] Epoch 144 Evaluation: +[14:47:51.191] average MSE: 0.1308278454305795 average PSNR: 21.847762183815647 average SSIM: 0.6207356009762064 +[14:48:29.377] iteration 83600 [38.16 sec]: learning rate : 0.000031 loss : 0.373901 +[14:49:16.672] iteration 83700 [85.46 sec]: learning rate : 0.000031 loss : 0.322014 +[14:50:04.319] iteration 83800 [133.10 sec]: learning rate : 0.000031 loss : 0.328559 +[14:50:51.623] iteration 83900 [180.41 sec]: learning rate : 0.000031 loss : 0.274122 +[14:51:39.079] iteration 84000 [227.86 sec]: learning rate : 0.000031 loss : 0.240419 +[14:52:24.594] Epoch 145 Evaluation: +[14:55:14.111] average MSE: 0.11720018097242146 average PSNR: 22.31315990353149 average SSIM: 0.6161800859137673 +[14:55:16.210] iteration 84100 [2.07 sec]: learning rate : 0.000031 loss : 0.306391 +[14:56:03.513] iteration 84200 [49.38 sec]: learning rate : 0.000031 loss : 0.311236 +[14:56:50.961] iteration 84300 [96.83 sec]: learning rate : 0.000031 loss : 0.311148 +[14:57:38.365] iteration 84400 [144.23 sec]: learning rate : 0.000031 loss : 0.413537 +[14:58:25.708] iteration 84500 [191.57 sec]: learning rate : 0.000031 loss : 0.282650 +[14:59:13.390] iteration 84600 [239.26 sec]: learning rate : 0.000031 loss : 0.241964 +[14:59:47.488] Epoch 146 Evaluation: +[15:02:38.123] average MSE: 0.12201795637230173 average PSNR: 22.14019973768581 average SSIM: 0.6152693918875832 +[15:02:51.695] iteration 84700 [13.55 sec]: learning rate : 0.000031 loss : 0.264746 +[15:03:39.001] iteration 84800 [60.85 sec]: learning rate : 0.000031 loss : 0.328650 +[15:04:26.404] iteration 84900 [108.26 sec]: learning rate : 0.000031 loss : 0.278281 +[15:05:13.702] iteration 85000 [155.55 sec]: learning rate : 0.000031 loss : 0.306376 
+[15:06:01.199] iteration 85100 [203.05 sec]: learning rate : 0.000031 loss : 0.329057 +[15:06:48.675] iteration 85200 [250.53 sec]: learning rate : 0.000031 loss : 0.334651 +[15:07:11.389] Epoch 147 Evaluation: +[15:10:02.504] average MSE: 0.12212941914220662 average PSNR: 22.137521557671825 average SSIM: 0.6153301359541854 +[15:10:27.370] iteration 85300 [24.84 sec]: learning rate : 0.000031 loss : 0.295554 +[15:11:15.099] iteration 85400 [72.57 sec]: learning rate : 0.000031 loss : 0.341108 +[15:12:02.795] iteration 85500 [120.27 sec]: learning rate : 0.000031 loss : 0.328893 +[15:12:50.126] iteration 85600 [167.60 sec]: learning rate : 0.000031 loss : 0.283013 +[15:13:37.686] iteration 85700 [215.16 sec]: learning rate : 0.000031 loss : 0.348995 +[15:14:25.274] iteration 85800 [262.75 sec]: learning rate : 0.000031 loss : 0.330508 +[15:14:36.684] Epoch 148 Evaluation: +[15:17:34.979] average MSE: 0.12498599356638385 average PSNR: 22.038312773319266 average SSIM: 0.6135695297391281 +[15:18:11.327] iteration 85900 [36.32 sec]: learning rate : 0.000031 loss : 0.266063 +[15:18:58.972] iteration 86000 [83.97 sec]: learning rate : 0.000031 loss : 0.293608 +[15:19:46.481] iteration 86100 [131.48 sec]: learning rate : 0.000031 loss : 0.376863 +[15:20:34.069] iteration 86200 [179.07 sec]: learning rate : 0.000031 loss : 0.301157 +[15:21:21.708] iteration 86300 [226.71 sec]: learning rate : 0.000031 loss : 0.322569 +[15:22:09.186] iteration 86400 [274.18 sec]: learning rate : 0.000031 loss : 0.365427 +[15:22:09.229] Epoch 149 Evaluation: +[15:25:08.623] average MSE: 0.11765317404759006 average PSNR: 22.297402271944733 average SSIM: 0.6161330934395568 +[15:25:56.541] iteration 86500 [47.89 sec]: learning rate : 0.000031 loss : 0.285462 +[15:26:43.923] iteration 86600 [95.28 sec]: learning rate : 0.000031 loss : 0.296322 +[15:27:31.245] iteration 86700 [142.60 sec]: learning rate : 0.000031 loss : 0.329720 +[15:28:18.682] iteration 86800 [190.03 sec]: learning rate : 0.000031 loss : 0.276039 +[15:29:06.141] iteration 86900 [237.49 sec]: learning rate : 0.000031 loss : 0.377784 +[15:29:42.103] Epoch 150 Evaluation: +[15:32:31.780] average MSE: 0.12345771336857272 average PSNR: 22.091489804297947 average SSIM: 0.6176200631452846 +[15:32:43.362] iteration 87000 [11.56 sec]: learning rate : 0.000031 loss : 0.363544 +[15:33:30.848] iteration 87100 [59.04 sec]: learning rate : 0.000031 loss : 0.258458 +[15:34:18.589] iteration 87200 [106.79 sec]: learning rate : 0.000031 loss : 0.347706 +[15:35:06.027] iteration 87300 [154.22 sec]: learning rate : 0.000031 loss : 0.380084 +[15:35:53.567] iteration 87400 [201.76 sec]: learning rate : 0.000031 loss : 0.258845 +[15:36:41.003] iteration 87500 [249.20 sec]: learning rate : 0.000031 loss : 0.330991 +[15:37:05.784] Epoch 151 Evaluation: +[15:39:57.969] average MSE: 0.12243091628484322 average PSNR: 22.127349930740763 average SSIM: 0.6215439950471744 +[15:40:21.031] iteration 87600 [23.04 sec]: learning rate : 0.000031 loss : 0.324958 +[15:41:08.747] iteration 87700 [70.75 sec]: learning rate : 0.000031 loss : 0.260603 +[15:41:56.050] iteration 87800 [118.06 sec]: learning rate : 0.000031 loss : 0.277552 +[15:42:43.452] iteration 87900 [165.46 sec]: learning rate : 0.000031 loss : 0.276832 +[15:43:31.161] iteration 88000 [213.17 sec]: learning rate : 0.000031 loss : 0.341453 +[15:44:18.486] iteration 88100 [260.49 sec]: learning rate : 0.000031 loss : 0.356607 +[15:44:31.736] Epoch 152 Evaluation: +[15:47:21.415] average MSE: 0.13270783493132005 average PSNR: 
21.791482297877685 average SSIM: 0.6249814140296981 +[15:47:55.816] iteration 88200 [34.38 sec]: learning rate : 0.000031 loss : 0.307250 +[15:48:43.244] iteration 88300 [81.80 sec]: learning rate : 0.000031 loss : 0.387544 +[15:49:30.550] iteration 88400 [129.11 sec]: learning rate : 0.000031 loss : 0.358339 +[15:50:17.953] iteration 88500 [176.51 sec]: learning rate : 0.000031 loss : 0.319062 +[15:51:05.285] iteration 88600 [223.84 sec]: learning rate : 0.000031 loss : 0.268488 +[15:51:53.011] iteration 88700 [271.57 sec]: learning rate : 0.000031 loss : 0.316744 +[15:51:54.914] Epoch 153 Evaluation: +[15:54:48.249] average MSE: 0.12482344768745562 average PSNR: 22.045919867594478 average SSIM: 0.6199428578083004 +[15:55:34.296] iteration 88800 [46.02 sec]: learning rate : 0.000031 loss : 0.366686 +[15:56:21.812] iteration 88900 [93.54 sec]: learning rate : 0.000031 loss : 0.262144 +[15:57:09.335] iteration 89000 [141.06 sec]: learning rate : 0.000031 loss : 0.400408 +[15:57:56.780] iteration 89100 [188.51 sec]: learning rate : 0.000031 loss : 0.305967 +[15:58:44.122] iteration 89200 [235.85 sec]: learning rate : 0.000031 loss : 0.297104 +[15:59:22.059] Epoch 154 Evaluation: +[16:02:13.678] average MSE: 0.12307958271714783 average PSNR: 22.105306711084403 average SSIM: 0.6164786789404231 +[16:02:23.335] iteration 89300 [9.63 sec]: learning rate : 0.000031 loss : 0.239981 +[16:03:10.795] iteration 89400 [57.09 sec]: learning rate : 0.000031 loss : 0.304952 +[16:03:58.247] iteration 89500 [104.55 sec]: learning rate : 0.000031 loss : 0.326085 +[16:04:45.854] iteration 89600 [152.15 sec]: learning rate : 0.000031 loss : 0.275973 +[16:05:33.400] iteration 89700 [199.70 sec]: learning rate : 0.000031 loss : 0.264918 +[16:06:21.029] iteration 89800 [247.33 sec]: learning rate : 0.000031 loss : 0.357144 +[16:06:47.625] Epoch 155 Evaluation: +[16:09:46.047] average MSE: 0.13664452209717362 average PSNR: 21.6636629450531 average SSIM: 0.6192625098913552 +[16:10:07.159] iteration 89900 [21.09 sec]: learning rate : 0.000031 loss : 0.327559 +[16:10:54.503] iteration 90000 [68.43 sec]: learning rate : 0.000031 loss : 0.317068 +[16:11:41.985] iteration 90100 [115.91 sec]: learning rate : 0.000031 loss : 0.342997 +[16:12:29.444] iteration 90200 [163.37 sec]: learning rate : 0.000031 loss : 0.351348 +[16:13:16.881] iteration 90300 [210.81 sec]: learning rate : 0.000031 loss : 0.266184 +[16:14:04.361] iteration 90400 [258.29 sec]: learning rate : 0.000031 loss : 0.296411 +[16:14:19.604] Epoch 156 Evaluation: +[16:17:09.549] average MSE: 0.11998035183902726 average PSNR: 22.21375613913734 average SSIM: 0.6172019207421899 +[16:17:42.153] iteration 90500 [32.58 sec]: learning rate : 0.000031 loss : 0.301769 +[16:18:29.641] iteration 90600 [80.07 sec]: learning rate : 0.000031 loss : 0.269740 +[16:19:17.070] iteration 90700 [127.50 sec]: learning rate : 0.000031 loss : 0.289434 +[16:20:04.465] iteration 90800 [174.89 sec]: learning rate : 0.000031 loss : 0.290054 +[16:20:51.846] iteration 90900 [222.27 sec]: learning rate : 0.000031 loss : 0.388485 +[16:21:39.274] iteration 91000 [269.70 sec]: learning rate : 0.000031 loss : 0.340213 +[16:21:43.085] Epoch 157 Evaluation: +[16:24:32.767] average MSE: 0.1281582587214426 average PSNR: 21.932770853363035 average SSIM: 0.6155287232522485 +[16:25:16.516] iteration 91100 [43.72 sec]: learning rate : 0.000031 loss : 0.302904 +[16:26:03.947] iteration 91200 [91.16 sec]: learning rate : 0.000031 loss : 0.363741 +[16:26:51.343] iteration 91300 [138.55 sec]: learning 
rate : 0.000031 loss : 0.253265 +[16:27:38.877] iteration 91400 [186.09 sec]: learning rate : 0.000031 loss : 0.309345 +[16:28:26.362] iteration 91500 [233.57 sec]: learning rate : 0.000031 loss : 0.284986 +[16:29:06.129] Epoch 158 Evaluation: +[16:31:55.789] average MSE: 0.12730819035552335 average PSNR: 21.962113670679486 average SSIM: 0.6183060876248412 +[16:32:03.556] iteration 91600 [7.74 sec]: learning rate : 0.000031 loss : 0.380917 +[16:32:50.852] iteration 91700 [55.04 sec]: learning rate : 0.000031 loss : 0.360425 +[16:33:38.402] iteration 91800 [102.59 sec]: learning rate : 0.000031 loss : 0.361450 +[16:34:25.991] iteration 91900 [150.18 sec]: learning rate : 0.000031 loss : 0.237137 +[16:35:13.520] iteration 92000 [197.71 sec]: learning rate : 0.000031 loss : 0.402828 +[16:36:01.174] iteration 92100 [245.36 sec]: learning rate : 0.000031 loss : 0.246947 +[16:36:29.844] Epoch 159 Evaluation: +[16:39:19.728] average MSE: 0.11326400246584634 average PSNR: 22.461021964861757 average SSIM: 0.6131674320865776 +[16:39:38.847] iteration 92200 [19.10 sec]: learning rate : 0.000031 loss : 0.301424 +[16:40:26.324] iteration 92300 [66.57 sec]: learning rate : 0.000031 loss : 0.316202 +[16:41:13.769] iteration 92400 [114.02 sec]: learning rate : 0.000031 loss : 0.293201 +[16:42:01.254] iteration 92500 [161.50 sec]: learning rate : 0.000031 loss : 0.324746 +[16:42:48.787] iteration 92600 [209.03 sec]: learning rate : 0.000031 loss : 0.286732 +[16:43:36.224] iteration 92700 [256.47 sec]: learning rate : 0.000031 loss : 0.381584 +[16:43:53.264] Epoch 160 Evaluation: +[16:46:48.406] average MSE: 0.12768411062423776 average PSNR: 21.948167073649216 average SSIM: 0.6161080153480087 +[16:47:18.884] iteration 92800 [30.45 sec]: learning rate : 0.000031 loss : 0.385538 +[16:48:06.381] iteration 92900 [77.95 sec]: learning rate : 0.000031 loss : 0.378694 +[16:48:53.801] iteration 93000 [125.37 sec]: learning rate : 0.000031 loss : 0.321708 +[16:49:41.222] iteration 93100 [172.79 sec]: learning rate : 0.000031 loss : 0.357402 +[16:50:28.678] iteration 93200 [220.25 sec]: learning rate : 0.000031 loss : 0.365085 +[16:51:16.010] iteration 93300 [267.58 sec]: learning rate : 0.000031 loss : 0.360015 +[16:51:21.701] Epoch 161 Evaluation: +[16:54:11.774] average MSE: 0.11509076457416084 average PSNR: 22.39157178118001 average SSIM: 0.6111503007255945 +[16:54:53.950] iteration 93400 [42.15 sec]: learning rate : 0.000031 loss : 0.330148 +[16:55:41.352] iteration 93500 [89.55 sec]: learning rate : 0.000031 loss : 0.307994 +[16:56:28.698] iteration 93600 [136.90 sec]: learning rate : 0.000031 loss : 0.327909 +[16:57:16.164] iteration 93700 [184.37 sec]: learning rate : 0.000031 loss : 0.329026 +[16:58:03.642] iteration 93800 [231.84 sec]: learning rate : 0.000031 loss : 0.318435 +[16:58:45.296] Epoch 162 Evaluation: +[17:01:35.713] average MSE: 0.1280103895427097 average PSNR: 21.93556424924035 average SSIM: 0.61049719712299 +[17:01:41.583] iteration 93900 [5.85 sec]: learning rate : 0.000031 loss : 0.355424 +[17:02:29.050] iteration 94000 [53.31 sec]: learning rate : 0.000031 loss : 0.329053 +[17:03:16.715] iteration 94100 [100.98 sec]: learning rate : 0.000031 loss : 0.280618 +[17:04:04.035] iteration 94200 [148.30 sec]: learning rate : 0.000031 loss : 0.376147 +[17:04:51.565] iteration 94300 [195.83 sec]: learning rate : 0.000031 loss : 0.249542 +[17:05:39.191] iteration 94400 [243.45 sec]: learning rate : 0.000031 loss : 0.434659 +[17:06:09.720] Epoch 163 Evaluation: +[17:09:09.027] average MSE: 
0.11721628861921018 average PSNR: 22.312617376022327 average SSIM: 0.6149499753507983 +[17:09:26.312] iteration 94500 [17.26 sec]: learning rate : 0.000031 loss : 0.266832 +[17:10:13.934] iteration 94600 [64.88 sec]: learning rate : 0.000031 loss : 0.357962 +[17:11:01.416] iteration 94700 [112.37 sec]: learning rate : 0.000031 loss : 0.275093 +[17:11:49.276] iteration 94800 [160.23 sec]: learning rate : 0.000031 loss : 0.335901 +[17:12:36.717] iteration 94900 [207.67 sec]: learning rate : 0.000031 loss : 0.348877 +[17:13:24.065] iteration 95000 [255.02 sec]: learning rate : 0.000031 loss : 0.264937 +[17:13:43.095] Epoch 164 Evaluation: +[17:16:33.812] average MSE: 0.11872706900086467 average PSNR: 22.257839029832436 average SSIM: 0.6142580357629688 +[17:17:02.383] iteration 95100 [28.55 sec]: learning rate : 0.000031 loss : 0.355394 +[17:17:49.856] iteration 95200 [76.02 sec]: learning rate : 0.000031 loss : 0.260242 +[17:18:37.184] iteration 95300 [123.35 sec]: learning rate : 0.000031 loss : 0.408497 +[17:19:24.583] iteration 95400 [170.75 sec]: learning rate : 0.000031 loss : 0.346032 +[17:20:12.023] iteration 95500 [218.19 sec]: learning rate : 0.000031 loss : 0.354747 +[17:20:59.800] iteration 95600 [265.96 sec]: learning rate : 0.000031 loss : 0.328503 +[17:21:07.398] Epoch 165 Evaluation: +[17:24:03.242] average MSE: 0.1145683274932215 average PSNR: 22.411950254819914 average SSIM: 0.6128289600648842 +[17:24:43.392] iteration 95700 [40.13 sec]: learning rate : 0.000031 loss : 0.299389 +[17:25:30.792] iteration 95800 [87.53 sec]: learning rate : 0.000031 loss : 0.346684 +[17:26:18.289] iteration 95900 [135.02 sec]: learning rate : 0.000031 loss : 0.297220 +[17:27:05.721] iteration 96000 [182.45 sec]: learning rate : 0.000031 loss : 0.237396 +[17:27:53.069] iteration 96100 [229.80 sec]: learning rate : 0.000031 loss : 0.377026 +[17:28:36.703] Epoch 166 Evaluation: +[17:31:32.870] average MSE: 0.1283302022963326 average PSNR: 21.929086804435634 average SSIM: 0.6170268136701186 +[17:31:36.859] iteration 96200 [3.96 sec]: learning rate : 0.000031 loss : 0.299242 +[17:32:24.402] iteration 96300 [51.51 sec]: learning rate : 0.000031 loss : 0.278231 +[17:33:11.803] iteration 96400 [98.91 sec]: learning rate : 0.000031 loss : 0.321602 +[17:33:59.323] iteration 96500 [146.43 sec]: learning rate : 0.000031 loss : 0.343905 +[17:34:46.839] iteration 96600 [193.97 sec]: learning rate : 0.000031 loss : 0.278860 +[17:35:34.288] iteration 96700 [241.39 sec]: learning rate : 0.000031 loss : 0.322538 +[17:36:06.602] Epoch 167 Evaluation: +[17:38:57.062] average MSE: 0.12863743990459955 average PSNR: 21.919514027240748 average SSIM: 0.6106977346021684 +[17:39:12.401] iteration 96800 [15.31 sec]: learning rate : 0.000031 loss : 0.299583 +[17:39:59.709] iteration 96900 [62.62 sec]: learning rate : 0.000031 loss : 0.342254 +[17:40:47.219] iteration 97000 [110.13 sec]: learning rate : 0.000031 loss : 0.316750 +[17:41:34.644] iteration 97100 [157.56 sec]: learning rate : 0.000031 loss : 0.308151 +[17:42:22.307] iteration 97200 [205.22 sec]: learning rate : 0.000031 loss : 0.283664 +[17:43:09.968] iteration 97300 [252.88 sec]: learning rate : 0.000031 loss : 0.279512 +[17:43:30.862] Epoch 168 Evaluation: +[17:46:29.940] average MSE: 0.11588628195654768 average PSNR: 22.36369289887033 average SSIM: 0.6136319608103035 +[17:46:57.141] iteration 97400 [27.18 sec]: learning rate : 0.000031 loss : 0.359947 +[17:47:44.434] iteration 97500 [74.47 sec]: learning rate : 0.000031 loss : 0.315896 +[17:48:31.835] 
iteration 97600 [121.87 sec]: learning rate : 0.000031 loss : 0.314271 +[17:49:19.224] iteration 97700 [169.26 sec]: learning rate : 0.000031 loss : 0.493240 +[17:50:06.656] iteration 97800 [216.69 sec]: learning rate : 0.000031 loss : 0.279838 +[17:50:54.251] iteration 97900 [264.29 sec]: learning rate : 0.000031 loss : 0.290301 +[17:51:03.722] Epoch 169 Evaluation: +[17:54:04.552] average MSE: 0.12135874583080065 average PSNR: 22.16400705870818 average SSIM: 0.5994423030178955 +[17:54:42.730] iteration 98000 [38.15 sec]: learning rate : 0.000031 loss : 0.377789 +[17:55:30.399] iteration 98100 [85.82 sec]: learning rate : 0.000031 loss : 0.313861 +[17:56:18.016] iteration 98200 [133.44 sec]: learning rate : 0.000031 loss : 0.287581 +[17:57:05.519] iteration 98300 [180.95 sec]: learning rate : 0.000031 loss : 0.218607 +[17:57:53.173] iteration 98400 [228.60 sec]: learning rate : 0.000031 loss : 0.244310 +[17:58:38.906] Epoch 170 Evaluation: +[18:01:38.187] average MSE: 0.11601832370660536 average PSNR: 22.357018779323795 average SSIM: 0.6112945462543597 +[18:01:40.309] iteration 98500 [2.10 sec]: learning rate : 0.000031 loss : 0.320956 +[18:02:27.746] iteration 98600 [49.54 sec]: learning rate : 0.000031 loss : 0.285553 +[18:03:15.313] iteration 98700 [97.10 sec]: learning rate : 0.000031 loss : 0.291123 +[18:04:02.833] iteration 98800 [144.62 sec]: learning rate : 0.000031 loss : 0.433745 +[18:04:50.578] iteration 98900 [192.37 sec]: learning rate : 0.000031 loss : 0.299969 +[18:05:38.344] iteration 99000 [240.13 sec]: learning rate : 0.000031 loss : 0.242963 +[18:06:12.576] Epoch 171 Evaluation: +[18:09:12.060] average MSE: 0.11836823130509996 average PSNR: 22.27247532139836 average SSIM: 0.6175527931049866 +[18:09:25.547] iteration 99100 [13.46 sec]: learning rate : 0.000031 loss : 0.263415 +[18:10:13.166] iteration 99200 [61.08 sec]: learning rate : 0.000031 loss : 0.305034 +[18:11:00.703] iteration 99300 [108.62 sec]: learning rate : 0.000031 loss : 0.304729 +[18:11:48.069] iteration 99400 [155.99 sec]: learning rate : 0.000031 loss : 0.341281 +[18:12:35.668] iteration 99500 [203.58 sec]: learning rate : 0.000031 loss : 0.289280 +[18:13:23.615] iteration 99600 [251.53 sec]: learning rate : 0.000031 loss : 0.288167 +[18:13:46.460] Epoch 172 Evaluation: +[18:16:42.841] average MSE: 0.12599396440110816 average PSNR: 22.002662499610565 average SSIM: 0.6116344788810175 +[18:17:07.639] iteration 99700 [24.77 sec]: learning rate : 0.000031 loss : 0.275722 +[18:17:55.149] iteration 99800 [72.28 sec]: learning rate : 0.000031 loss : 0.313841 +[18:18:42.585] iteration 99900 [119.72 sec]: learning rate : 0.000031 loss : 0.340052 +[18:19:29.903] iteration 100000 [167.04 sec]: learning rate : 0.000008 loss : 0.284032 +[18:19:30.061] save model to model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/iter_100000.pth +[18:19:30.550] Epoch 173 Evaluation: +[18:22:21.109] average MSE: 0.12252341071058222 average PSNR: 22.124341611762574 average SSIM: 0.6179250584283125 +[18:22:21.453] save model to model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/iter_100000.pth +===> Evaluate Metric <=== +Results +------------------------------------ +ColdDiffusion NMSE: 1.1987 ± 0.0580 +ColdDiffusion PSNR: 33.1367 ± 0.4128 +ColdDiffusion SSIM: 0.8699 ± 0.0097 +------------------------------------ +All NMSE: 1.1958 ± 0.1985 +All PSNR: 32.1098 ± 0.8062 +All SSIM: 0.8505 ± 0.0165 +------------------------------------ +Save Path: 
/home/v-qichen3/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/result_case/ \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/log/events.out.tfevents.1752550820.GCRSANDBOX133 b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/log/events.out.tfevents.1752550820.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..79693d4b9c2005c9af247dad0a2638cb0722c009 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t20_new_kspace_time/log/events.out.tfevents.1752550820.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa058dd78a96840e4185c3e072a6e5abb2f74493bc6a7067adb9962ef11d97fc +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..3ed328145b9023dc095b1e1349834ce7ee223bb9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:577c51d882585d289395f85508d38ddb6204eb8ec51490ad2dd7200c6a487d58 +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..97c8b14407a7ec6d8b983f06765f2cea8994e87d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/log.txt @@ -0,0 +1,1365 @@ +[20:35:28.707] Namespace(root_path='/home/v-qichen3/MRI_recon/data/m4raw', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='3', exp='FSMNet_m4raw_4x_lr5e-4', max_iterations=100000, batch_size=4, base_lr=0.0005, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=30, image_size=240, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[20:41:50.371] iteration 100 [48.74 sec]: learning rate : 0.000500 loss : 0.627515 +[20:42:37.966] iteration 200 [96.33 sec]: learning rate : 0.000500 loss : 0.532187 +[20:43:26.212] iteration 300 [144.58 sec]: learning rate : 0.000500 loss : 0.356054 +[20:44:14.013] iteration 400 [192.38 sec]: learning rate : 0.000500 loss : 0.378413 +[20:45:01.486] iteration 500 [239.85 sec]: learning rate : 0.000500 loss : 0.405908 +[20:45:37.781] Epoch 0 Evaluation: +[20:50:09.686] average MSE: 0.05040822774134444 average PSNR: 26.017285826160233 average SSIM: 0.6883571496861531 +[20:50:21.309] iteration 600 [11.60 sec]: learning rate : 0.000500 loss : 0.354253 +[20:51:08.911] iteration 700 [59.20 sec]: learning rate : 0.000500 loss : 0.492697 +[20:51:56.822] iteration 800 [107.11 sec]: learning 
rate : 0.000500 loss : 0.481575 +[20:52:44.412] iteration 900 [154.70 sec]: learning rate : 0.000500 loss : 0.388467 +[20:53:32.033] iteration 1000 [202.32 sec]: learning rate : 0.000500 loss : 0.480756 +[20:54:19.958] iteration 1100 [250.25 sec]: learning rate : 0.000500 loss : 0.291611 +[20:54:44.733] Epoch 1 Evaluation: +[20:59:08.650] average MSE: 0.04617064575356903 average PSNR: 26.38918526400511 average SSIM: 0.7132331284152149 +[20:59:31.724] iteration 1200 [23.05 sec]: learning rate : 0.000500 loss : 0.423580 +[21:00:19.837] iteration 1300 [71.16 sec]: learning rate : 0.000500 loss : 0.395379 +[21:01:07.486] iteration 1400 [118.81 sec]: learning rate : 0.000500 loss : 0.355628 +[21:01:55.159] iteration 1500 [166.49 sec]: learning rate : 0.000500 loss : 0.372215 +[21:02:42.835] iteration 1600 [214.16 sec]: learning rate : 0.000500 loss : 0.429603 +[21:03:31.134] iteration 1700 [262.46 sec]: learning rate : 0.000500 loss : 0.446887 +[21:03:44.517] Epoch 2 Evaluation: +[21:08:13.126] average MSE: 0.045278943366565225 average PSNR: 26.473271005697608 average SSIM: 0.7165980924304832 +[21:08:47.711] iteration 1800 [34.56 sec]: learning rate : 0.000500 loss : 0.411159 +[21:09:35.258] iteration 1900 [82.11 sec]: learning rate : 0.000500 loss : 0.315845 +[21:10:23.825] iteration 2000 [130.68 sec]: learning rate : 0.000500 loss : 0.409471 +[21:11:12.062] iteration 2100 [178.92 sec]: learning rate : 0.000500 loss : 0.365295 +[21:11:59.747] iteration 2200 [226.60 sec]: learning rate : 0.000500 loss : 0.360535 +[21:12:47.486] iteration 2300 [274.34 sec]: learning rate : 0.000500 loss : 0.398745 +[21:12:49.408] Epoch 3 Evaluation: +[21:17:07.930] average MSE: 0.04592542072713889 average PSNR: 26.410146381605582 average SSIM: 0.7123885692673719 +[21:17:54.407] iteration 2400 [46.48 sec]: learning rate : 0.000500 loss : 0.418968 +[21:18:41.940] iteration 2500 [93.99 sec]: learning rate : 0.000500 loss : 0.408285 +[21:19:29.555] iteration 2600 [141.60 sec]: learning rate : 0.000500 loss : 0.359444 +[21:20:17.097] iteration 2700 [189.14 sec]: learning rate : 0.000500 loss : 0.381598 +[21:21:04.765] iteration 2800 [236.81 sec]: learning rate : 0.000500 loss : 0.426676 +[21:21:42.997] Epoch 4 Evaluation: +[21:26:00.593] average MSE: 0.04357969275083966 average PSNR: 26.647516301262016 average SSIM: 0.7273628216465439 +[21:26:10.301] iteration 2900 [9.69 sec]: learning rate : 0.000500 loss : 0.409787 +[21:26:57.970] iteration 3000 [57.35 sec]: learning rate : 0.000500 loss : 0.474151 +[21:27:46.207] iteration 3100 [105.59 sec]: learning rate : 0.000500 loss : 0.478769 +[21:28:34.206] iteration 3200 [153.59 sec]: learning rate : 0.000500 loss : 0.355551 +[21:29:21.806] iteration 3300 [201.19 sec]: learning rate : 0.000500 loss : 0.479880 +[21:30:09.519] iteration 3400 [248.90 sec]: learning rate : 0.000500 loss : 0.387730 +[21:30:36.152] Epoch 5 Evaluation: +[21:34:53.260] average MSE: 0.04363711034572926 average PSNR: 26.640064328044566 average SSIM: 0.7172458323050561 +[21:35:14.467] iteration 3500 [21.19 sec]: learning rate : 0.000500 loss : 0.353708 +[21:36:01.962] iteration 3600 [68.68 sec]: learning rate : 0.000500 loss : 0.366242 +[21:36:49.633] iteration 3700 [116.35 sec]: learning rate : 0.000500 loss : 0.336629 +[21:37:37.210] iteration 3800 [163.93 sec]: learning rate : 0.000500 loss : 0.460578 +[21:38:24.827] iteration 3900 [211.54 sec]: learning rate : 0.000500 loss : 0.381222 +[21:39:12.824] iteration 4000 [259.54 sec]: learning rate : 0.000500 loss : 0.412475 +[21:39:28.028] Epoch 6 
Evaluation: +[21:43:42.462] average MSE: 0.051411275219452164 average PSNR: 25.91712051001767 average SSIM: 0.698873020525152 +[21:44:14.935] iteration 4100 [32.45 sec]: learning rate : 0.000500 loss : 0.326462 +[21:45:03.483] iteration 4200 [81.00 sec]: learning rate : 0.000500 loss : 0.422304 +[21:45:51.286] iteration 4300 [128.80 sec]: learning rate : 0.000500 loss : 0.388645 +[21:46:39.292] iteration 4400 [176.81 sec]: learning rate : 0.000500 loss : 0.333050 +[21:47:27.028] iteration 4500 [224.54 sec]: learning rate : 0.000500 loss : 0.418024 +[21:48:14.928] iteration 4600 [272.44 sec]: learning rate : 0.000500 loss : 0.306262 +[21:48:18.748] Epoch 7 Evaluation: +[21:52:47.371] average MSE: 0.045082423544961246 average PSNR: 26.508409879717487 average SSIM: 0.7158062284552004 +[21:53:31.220] iteration 4700 [43.83 sec]: learning rate : 0.000500 loss : 0.375724 +[21:54:18.860] iteration 4800 [91.47 sec]: learning rate : 0.000500 loss : 0.408971 +[21:55:06.438] iteration 4900 [139.04 sec]: learning rate : 0.000500 loss : 0.411843 +[21:55:54.764] iteration 5000 [187.37 sec]: learning rate : 0.000500 loss : 0.407588 +[21:56:42.344] iteration 5100 [234.95 sec]: learning rate : 0.000500 loss : 0.330621 +[21:57:22.240] Epoch 8 Evaluation: +[22:01:50.381] average MSE: 0.05085455712976653 average PSNR: 25.977030866498808 average SSIM: 0.6863325410340786 +[22:01:58.190] iteration 5200 [7.79 sec]: learning rate : 0.000500 loss : 0.356508 +[22:02:46.005] iteration 5300 [55.60 sec]: learning rate : 0.000500 loss : 0.411688 +[22:03:34.152] iteration 5400 [103.75 sec]: learning rate : 0.000500 loss : 0.446249 +[22:04:21.809] iteration 5500 [151.41 sec]: learning rate : 0.000500 loss : 0.347022 +[22:05:10.024] iteration 5600 [199.62 sec]: learning rate : 0.000500 loss : 0.321296 +[22:05:57.840] iteration 5700 [247.44 sec]: learning rate : 0.000500 loss : 0.359239 +[22:06:26.478] Epoch 9 Evaluation: +[22:10:45.167] average MSE: 0.046849328615205695 average PSNR: 26.340325866036544 average SSIM: 0.7027765827605098 +[22:11:04.392] iteration 5800 [19.20 sec]: learning rate : 0.000500 loss : 0.361839 +[22:11:52.101] iteration 5900 [66.91 sec]: learning rate : 0.000500 loss : 0.383517 +[22:12:40.218] iteration 6000 [115.03 sec]: learning rate : 0.000500 loss : 0.406519 +[22:13:27.676] iteration 6100 [162.49 sec]: learning rate : 0.000500 loss : 0.481683 +[22:14:15.265] iteration 6200 [210.08 sec]: learning rate : 0.000500 loss : 0.350752 +[22:15:03.000] iteration 6300 [257.81 sec]: learning rate : 0.000500 loss : 0.299752 +[22:15:20.150] Epoch 10 Evaluation: +[22:19:46.651] average MSE: 0.06448533164368918 average PSNR: 24.92590478682744 average SSIM: 0.6214960069991476 +[22:20:17.349] iteration 6400 [30.68 sec]: learning rate : 0.000500 loss : 0.395029 +[22:21:05.127] iteration 6500 [78.45 sec]: learning rate : 0.000500 loss : 0.405457 +[22:21:52.744] iteration 6600 [126.07 sec]: learning rate : 0.000500 loss : 0.341757 +[22:22:40.458] iteration 6700 [173.79 sec]: learning rate : 0.000500 loss : 0.362565 +[22:23:29.171] iteration 6800 [222.50 sec]: learning rate : 0.000500 loss : 0.267988 +[22:24:16.768] iteration 6900 [270.09 sec]: learning rate : 0.000500 loss : 0.260063 +[22:24:22.470] Epoch 11 Evaluation: +[22:28:47.366] average MSE: 0.053847312483235835 average PSNR: 25.725139401466777 average SSIM: 0.6627044387825767 +[22:29:30.536] iteration 7000 [43.15 sec]: learning rate : 0.000500 loss : 0.377538 +[22:30:18.309] iteration 7100 [90.92 sec]: learning rate : 0.000500 loss : 0.350466 +[22:31:05.973] 
iteration 7200 [138.59 sec]: learning rate : 0.000500 loss : 0.321862 +[22:31:53.766] iteration 7300 [186.38 sec]: learning rate : 0.000500 loss : 0.427665 +[22:32:41.553] iteration 7400 [234.16 sec]: learning rate : 0.000500 loss : 0.350136 +[22:33:24.036] Epoch 12 Evaluation: +[22:37:42.249] average MSE: 0.05550422989881967 average PSNR: 25.602783109805944 average SSIM: 0.6603208585163284 +[22:37:48.158] iteration 7500 [5.89 sec]: learning rate : 0.000500 loss : 0.333671 +[22:38:36.018] iteration 7600 [53.75 sec]: learning rate : 0.000500 loss : 0.250014 +[22:39:23.759] iteration 7700 [101.49 sec]: learning rate : 0.000500 loss : 0.325881 +[22:40:11.511] iteration 7800 [149.26 sec]: learning rate : 0.000500 loss : 0.317314 +[22:41:00.061] iteration 7900 [197.79 sec]: learning rate : 0.000500 loss : 0.319906 +[22:41:47.855] iteration 8000 [245.58 sec]: learning rate : 0.000500 loss : 0.272843 +[22:42:18.468] Epoch 13 Evaluation: +[22:46:44.707] average MSE: 0.07104059758919679 average PSNR: 24.497563840593056 average SSIM: 0.6237500183991462 +[22:47:01.981] iteration 8100 [17.25 sec]: learning rate : 0.000500 loss : 0.356630 +[22:47:50.692] iteration 8200 [65.97 sec]: learning rate : 0.000500 loss : 0.442398 +[22:48:38.298] iteration 8300 [113.57 sec]: learning rate : 0.000500 loss : 0.331265 +[22:49:26.118] iteration 8400 [161.39 sec]: learning rate : 0.000500 loss : 0.299925 +[22:50:13.924] iteration 8500 [209.20 sec]: learning rate : 0.000500 loss : 0.315553 +[22:51:01.620] iteration 8600 [256.89 sec]: learning rate : 0.000500 loss : 0.295324 +[22:51:20.761] Epoch 14 Evaluation: +[22:55:43.032] average MSE: 0.061203176265998305 average PSNR: 25.173504837986467 average SSIM: 0.6526919720829326 +[22:56:12.105] iteration 8700 [29.06 sec]: learning rate : 0.000500 loss : 0.366193 +[22:56:59.976] iteration 8800 [76.94 sec]: learning rate : 0.000500 loss : 0.321216 +[22:57:47.701] iteration 8900 [124.65 sec]: learning rate : 0.000500 loss : 0.397147 +[22:58:35.234] iteration 9000 [172.18 sec]: learning rate : 0.000500 loss : 0.425785 +[22:59:23.093] iteration 9100 [220.04 sec]: learning rate : 0.000500 loss : 0.393286 +[23:00:10.689] iteration 9200 [267.64 sec]: learning rate : 0.000500 loss : 0.361890 +[23:00:18.295] Epoch 15 Evaluation: +[23:04:34.256] average MSE: 0.05608452981625135 average PSNR: 25.549950912606846 average SSIM: 0.6569855042288532 +[23:05:14.376] iteration 9300 [40.09 sec]: learning rate : 0.000500 loss : 0.412878 +[23:06:02.459] iteration 9400 [88.18 sec]: learning rate : 0.000500 loss : 0.362261 +[23:06:50.096] iteration 9500 [135.81 sec]: learning rate : 0.000500 loss : 0.389729 +[23:07:37.670] iteration 9600 [183.39 sec]: learning rate : 0.000500 loss : 0.301824 +[23:08:25.771] iteration 9700 [231.49 sec]: learning rate : 0.000500 loss : 0.343341 +[23:09:09.634] Epoch 16 Evaluation: +[23:13:38.295] average MSE: 0.04871029199411792 average PSNR: 26.159996955661533 average SSIM: 0.683859804737677 +[23:13:42.259] iteration 9800 [3.94 sec]: learning rate : 0.000500 loss : 0.419562 +[23:14:30.169] iteration 9900 [51.85 sec]: learning rate : 0.000500 loss : 0.312834 +[23:15:17.906] iteration 10000 [99.59 sec]: learning rate : 0.000500 loss : 0.309299 +[23:16:05.499] iteration 10100 [147.18 sec]: learning rate : 0.000500 loss : 0.346864 +[23:16:52.858] iteration 10200 [194.54 sec]: learning rate : 0.000500 loss : 0.365183 +[23:17:40.749] iteration 10300 [242.43 sec]: learning rate : 0.000500 loss : 0.318103 +[23:18:13.342] Epoch 17 Evaluation: +[23:22:29.987] average MSE: 
0.049893949450163254 average PSNR: 26.054930702308205 average SSIM: 0.6776341550275723 +[23:22:45.484] iteration 10400 [15.47 sec]: learning rate : 0.000500 loss : 0.434853 +[23:23:33.439] iteration 10500 [63.43 sec]: learning rate : 0.000500 loss : 0.414831 +[23:24:20.930] iteration 10600 [110.92 sec]: learning rate : 0.000500 loss : 0.424417 +[23:25:08.954] iteration 10700 [158.94 sec]: learning rate : 0.000500 loss : 0.309434 +[23:25:56.836] iteration 10800 [206.83 sec]: learning rate : 0.000500 loss : 0.411708 +[23:26:44.420] iteration 10900 [254.41 sec]: learning rate : 0.000500 loss : 0.370453 +[23:27:05.288] Epoch 18 Evaluation: +[23:31:27.591] average MSE: 0.055258512800036026 average PSNR: 25.61625943182992 average SSIM: 0.6590136172692033 +[23:31:54.420] iteration 11000 [26.81 sec]: learning rate : 0.000500 loss : 0.299403 +[23:32:42.032] iteration 11100 [74.42 sec]: learning rate : 0.000500 loss : 0.278782 +[23:33:30.007] iteration 11200 [122.39 sec]: learning rate : 0.000500 loss : 0.396928 +[23:34:17.435] iteration 11300 [169.82 sec]: learning rate : 0.000500 loss : 0.342244 +[23:35:04.925] iteration 11400 [217.31 sec]: learning rate : 0.000500 loss : 0.268799 +[23:35:52.949] iteration 11500 [265.34 sec]: learning rate : 0.000500 loss : 0.385652 +[23:36:02.425] Epoch 19 Evaluation: +[23:40:20.257] average MSE: 0.043238996777722256 average PSNR: 26.67054804118026 average SSIM: 0.7101235194227054 +[23:40:58.534] iteration 11600 [38.25 sec]: learning rate : 0.000500 loss : 0.419740 +[23:41:47.133] iteration 11700 [86.87 sec]: learning rate : 0.000500 loss : 0.281096 +[23:42:34.600] iteration 11800 [134.32 sec]: learning rate : 0.000500 loss : 0.340919 +[23:43:22.207] iteration 11900 [181.93 sec]: learning rate : 0.000500 loss : 0.344761 +[23:44:09.724] iteration 12000 [229.44 sec]: learning rate : 0.000500 loss : 0.292479 +[23:44:55.284] Epoch 20 Evaluation: +[23:49:18.366] average MSE: 0.04707938558430731 average PSNR: 26.3183843410823 average SSIM: 0.6928336228135012 +[23:49:20.441] iteration 12100 [2.05 sec]: learning rate : 0.000500 loss : 0.345593 +[23:50:08.061] iteration 12200 [49.67 sec]: learning rate : 0.000500 loss : 0.292931 +[23:50:55.644] iteration 12300 [97.26 sec]: learning rate : 0.000500 loss : 0.277525 +[23:51:43.194] iteration 12400 [144.81 sec]: learning rate : 0.000500 loss : 0.312610 +[23:52:30.726] iteration 12500 [192.34 sec]: learning rate : 0.000500 loss : 0.345185 +[23:53:18.559] iteration 12600 [240.17 sec]: learning rate : 0.000500 loss : 0.422942 +[23:53:53.230] Epoch 21 Evaluation: +[23:58:19.371] average MSE: 0.05297333056469738 average PSNR: 25.79621871065568 average SSIM: 0.6663884944600372 +[23:58:33.328] iteration 12700 [13.93 sec]: learning rate : 0.000500 loss : 0.411172 +[23:59:21.022] iteration 12800 [61.63 sec]: learning rate : 0.000500 loss : 0.359981 +[00:00:09.252] iteration 12900 [109.86 sec]: learning rate : 0.000500 loss : 0.302110 +[00:00:56.828] iteration 13000 [157.44 sec]: learning rate : 0.000500 loss : 0.311051 +[00:01:44.496] iteration 13100 [205.10 sec]: learning rate : 0.000500 loss : 0.312233 +[00:02:31.953] iteration 13200 [252.56 sec]: learning rate : 0.000500 loss : 0.335453 +[00:02:55.147] Epoch 22 Evaluation: +[00:07:20.740] average MSE: 0.05428217969555542 average PSNR: 25.722296109693296 average SSIM: 0.6855062109128651 +[00:07:45.690] iteration 13300 [24.93 sec]: learning rate : 0.000500 loss : 0.319434 +[00:08:33.372] iteration 13400 [72.61 sec]: learning rate : 0.000500 loss : 0.403086 +[00:09:20.902] iteration 
13500 [120.14 sec]: learning rate : 0.000500 loss : 0.404759 +[00:10:08.986] iteration 13600 [168.22 sec]: learning rate : 0.000500 loss : 0.326808 +[00:10:57.388] iteration 13700 [216.63 sec]: learning rate : 0.000500 loss : 0.344020 +[00:11:44.774] iteration 13800 [264.01 sec]: learning rate : 0.000500 loss : 0.299164 +[00:11:56.146] Epoch 23 Evaluation: +[00:16:12.243] average MSE: 0.041717187130703916 average PSNR: 26.833206136819598 average SSIM: 0.7270994761319575 +[00:16:48.746] iteration 13900 [36.48 sec]: learning rate : 0.000500 loss : 0.301771 +[00:17:36.376] iteration 14000 [84.11 sec]: learning rate : 0.000500 loss : 0.318389 +[00:18:24.475] iteration 14100 [132.21 sec]: learning rate : 0.000500 loss : 0.302119 +[00:19:11.957] iteration 14200 [179.69 sec]: learning rate : 0.000500 loss : 0.311020 +[00:19:59.341] iteration 14300 [227.08 sec]: learning rate : 0.000500 loss : 0.353999 +[00:20:47.098] iteration 14400 [274.83 sec]: learning rate : 0.000500 loss : 0.370136 +[00:20:47.133] Epoch 24 Evaluation: +[00:25:01.874] average MSE: 0.056276024874737714 average PSNR: 25.509088389724276 average SSIM: 0.6976425235572509 +[00:25:50.275] iteration 14500 [48.38 sec]: learning rate : 0.000500 loss : 0.417182 +[00:26:38.101] iteration 14600 [96.20 sec]: learning rate : 0.000500 loss : 0.358262 +[00:27:25.756] iteration 14700 [143.86 sec]: learning rate : 0.000500 loss : 0.258740 +[00:28:13.484] iteration 14800 [191.59 sec]: learning rate : 0.000500 loss : 0.266346 +[00:29:01.146] iteration 14900 [239.25 sec]: learning rate : 0.000500 loss : 0.294133 +[00:29:37.476] Epoch 25 Evaluation: +[00:34:00.288] average MSE: 0.053499076879608636 average PSNR: 25.738986817020148 average SSIM: 0.6649466818616716 +[00:34:11.919] iteration 15000 [11.61 sec]: learning rate : 0.000500 loss : 0.303730 +[00:34:59.694] iteration 15100 [59.38 sec]: learning rate : 0.000500 loss : 0.358511 +[00:35:47.103] iteration 15200 [106.79 sec]: learning rate : 0.000500 loss : 0.427657 +[00:36:34.977] iteration 15300 [154.67 sec]: learning rate : 0.000500 loss : 0.409343 +[00:37:22.372] iteration 15400 [202.06 sec]: learning rate : 0.000500 loss : 0.366470 +[00:38:10.441] iteration 15500 [250.13 sec]: learning rate : 0.000500 loss : 0.274518 +[00:38:35.179] Epoch 26 Evaluation: +[00:42:58.028] average MSE: 0.07254499521011826 average PSNR: 24.403151267330387 average SSIM: 0.626357795192006 +[00:43:21.244] iteration 15600 [23.19 sec]: learning rate : 0.000500 loss : 0.347581 +[00:44:08.639] iteration 15700 [70.59 sec]: learning rate : 0.000500 loss : 0.373327 +[00:44:56.071] iteration 15800 [118.02 sec]: learning rate : 0.000500 loss : 0.349536 +[00:45:43.546] iteration 15900 [165.49 sec]: learning rate : 0.000500 loss : 0.308357 +[00:46:30.910] iteration 16000 [212.86 sec]: learning rate : 0.000500 loss : 0.376383 +[00:47:18.406] iteration 16100 [260.35 sec]: learning rate : 0.000500 loss : 0.404555 +[00:47:31.716] Epoch 27 Evaluation: +[00:51:50.805] average MSE: 0.0711067865963918 average PSNR: 24.49399901763127 average SSIM: 0.6165216025864971 +[00:52:25.414] iteration 16200 [34.59 sec]: learning rate : 0.000500 loss : 0.348619 +[00:53:12.835] iteration 16300 [82.01 sec]: learning rate : 0.000500 loss : 0.321446 +[00:54:00.449] iteration 16400 [129.62 sec]: learning rate : 0.000500 loss : 0.351143 +[00:54:48.810] iteration 16500 [178.00 sec]: learning rate : 0.000500 loss : 0.270938 +[00:55:36.648] iteration 16600 [225.82 sec]: learning rate : 0.000500 loss : 0.310949 +[00:56:24.347] iteration 16700 [273.52 sec]: 
learning rate : 0.000500 loss : 0.357042 +[00:56:26.261] Epoch 28 Evaluation: +[01:00:47.458] average MSE: 0.05982177785936191 average PSNR: 25.26817113449225 average SSIM: 0.6453101434050336 +[01:01:33.322] iteration 16800 [45.84 sec]: learning rate : 0.000500 loss : 0.279048 +[01:02:20.963] iteration 16900 [93.48 sec]: learning rate : 0.000500 loss : 0.369734 +[01:03:08.639] iteration 17000 [141.16 sec]: learning rate : 0.000500 loss : 0.352296 +[01:03:56.000] iteration 17100 [188.52 sec]: learning rate : 0.000500 loss : 0.338605 +[01:04:43.429] iteration 17200 [235.95 sec]: learning rate : 0.000500 loss : 0.340661 +[01:05:21.446] Epoch 29 Evaluation: +[01:09:43.448] average MSE: 0.05359009193546284 average PSNR: 25.75556993457597 average SSIM: 0.6636031315282328 +[01:09:53.243] iteration 17300 [9.77 sec]: learning rate : 0.000500 loss : 0.346827 +[01:10:41.005] iteration 17400 [57.53 sec]: learning rate : 0.000500 loss : 0.354437 +[01:11:28.891] iteration 17500 [105.42 sec]: learning rate : 0.000500 loss : 0.323755 +[01:12:16.449] iteration 17600 [152.98 sec]: learning rate : 0.000500 loss : 0.308290 +[01:13:04.167] iteration 17700 [200.70 sec]: learning rate : 0.000500 loss : 0.311435 +[01:13:51.726] iteration 17800 [248.26 sec]: learning rate : 0.000500 loss : 0.304004 +[01:14:18.325] Epoch 30 Evaluation: +[01:18:34.560] average MSE: 0.06762057535891831 average PSNR: 24.71193460830237 average SSIM: 0.6213984183526334 +[01:18:56.218] iteration 17900 [21.63 sec]: learning rate : 0.000500 loss : 0.361088 +[01:19:43.781] iteration 18000 [69.20 sec]: learning rate : 0.000500 loss : 0.413711 +[01:20:31.319] iteration 18100 [116.73 sec]: learning rate : 0.000500 loss : 0.352314 +[01:21:18.697] iteration 18200 [164.11 sec]: learning rate : 0.000500 loss : 0.388572 +[01:22:06.164] iteration 18300 [211.58 sec]: learning rate : 0.000500 loss : 0.392733 +[01:22:54.254] iteration 18400 [259.67 sec]: learning rate : 0.000500 loss : 0.354726 +[01:23:09.469] Epoch 31 Evaluation: +[01:27:35.915] average MSE: 0.059019939702786434 average PSNR: 25.31367885561547 average SSIM: 0.6464441693817459 +[01:28:08.725] iteration 18500 [32.79 sec]: learning rate : 0.000500 loss : 0.302703 +[01:28:56.334] iteration 18600 [80.40 sec]: learning rate : 0.000500 loss : 0.310927 +[01:29:43.994] iteration 18700 [128.06 sec]: learning rate : 0.000500 loss : 0.299435 +[01:30:31.565] iteration 18800 [175.63 sec]: learning rate : 0.000500 loss : 0.281799 +[01:31:19.751] iteration 18900 [223.81 sec]: learning rate : 0.000500 loss : 0.356086 +[01:32:07.361] iteration 19000 [271.45 sec]: learning rate : 0.000500 loss : 0.305142 +[01:32:11.167] Epoch 32 Evaluation: +[01:36:27.970] average MSE: 0.04794824567462472 average PSNR: 26.231868125249274 average SSIM: 0.6854452884807176 +[01:37:12.535] iteration 19100 [44.54 sec]: learning rate : 0.000500 loss : 0.348708 +[01:38:00.055] iteration 19200 [92.06 sec]: learning rate : 0.000500 loss : 0.424087 +[01:38:47.404] iteration 19300 [139.41 sec]: learning rate : 0.000500 loss : 0.346287 +[01:39:35.392] iteration 19400 [187.40 sec]: learning rate : 0.000500 loss : 0.622166 +[01:40:23.072] iteration 19500 [235.08 sec]: learning rate : 0.000500 loss : 0.309631 +[01:41:03.443] Epoch 33 Evaluation: +[01:45:35.372] average MSE: 0.05149614100370233 average PSNR: 25.919712199344744 average SSIM: 0.6678509156048575 +[01:45:43.151] iteration 19600 [7.76 sec]: learning rate : 0.000500 loss : 0.313457 +[01:46:30.646] iteration 19700 [55.25 sec]: learning rate : 0.000500 loss : 0.464844 
+[01:47:18.109] iteration 19800 [102.71 sec]: learning rate : 0.000500 loss : 0.394760 +[01:48:05.867] iteration 19900 [150.47 sec]: learning rate : 0.000500 loss : 0.339991 +[01:48:53.352] iteration 20000 [197.96 sec]: learning rate : 0.000125 loss : 0.537183 +[01:48:53.509] save model to model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/iter_20000.pth +[01:49:41.392] iteration 20100 [246.00 sec]: learning rate : 0.000250 loss : 0.445133 +[01:50:09.818] Epoch 34 Evaluation: +[01:54:25.990] average MSE: 0.0646095687927448 average PSNR: 24.90529330640544 average SSIM: 0.6156210272696034 +[01:54:45.135] iteration 20200 [19.12 sec]: learning rate : 0.000250 loss : 0.300140 +[01:55:33.827] iteration 20300 [67.82 sec]: learning rate : 0.000250 loss : 0.321831 +[01:56:21.754] iteration 20400 [115.74 sec]: learning rate : 0.000250 loss : 0.336544 +[01:57:09.258] iteration 20500 [163.25 sec]: learning rate : 0.000250 loss : 0.379194 +[01:57:56.768] iteration 20600 [210.76 sec]: learning rate : 0.000250 loss : 0.345535 +[01:58:44.158] iteration 20700 [258.15 sec]: learning rate : 0.000250 loss : 0.339839 +[01:59:01.347] Epoch 35 Evaluation: +[02:03:17.212] average MSE: 0.06116926602917821 average PSNR: 25.163194324708144 average SSIM: 0.6343175009155524 +[02:03:47.710] iteration 20800 [30.47 sec]: learning rate : 0.000250 loss : 0.396276 +[02:04:35.336] iteration 20900 [78.10 sec]: learning rate : 0.000250 loss : 0.385397 +[02:05:22.820] iteration 21000 [125.59 sec]: learning rate : 0.000250 loss : 0.294728 +[02:06:10.436] iteration 21100 [173.20 sec]: learning rate : 0.000250 loss : 0.432991 +[02:06:58.023] iteration 21200 [220.79 sec]: learning rate : 0.000250 loss : 0.290106 +[02:07:46.517] iteration 21300 [269.29 sec]: learning rate : 0.000250 loss : 17.507614 +[02:07:52.235] Epoch 36 Evaluation: +[02:12:11.394] average MSE: 0.09244395039994226 average PSNR: 23.344384048763022 average SSIM: 0.6104131326831591 +[02:12:54.256] iteration 21400 [42.84 sec]: learning rate : 0.000250 loss : 0.340580 +[02:13:42.122] iteration 21500 [90.70 sec]: learning rate : 0.000250 loss : 0.264067 +[02:14:29.603] iteration 21600 [138.19 sec]: learning rate : 0.000250 loss : 1.482431 +[02:15:17.084] iteration 21700 [185.67 sec]: learning rate : 0.000250 loss : 0.430851 +[02:16:04.418] iteration 21800 [233.00 sec]: learning rate : 0.000250 loss : 0.356972 +[02:16:46.194] Epoch 37 Evaluation: +[02:21:01.289] average MSE: 0.06034406882215407 average PSNR: 25.222095567735813 average SSIM: 0.636740070248147 +[02:21:07.162] iteration 21900 [5.85 sec]: learning rate : 0.000250 loss : 0.389787 +[02:21:54.691] iteration 22000 [53.38 sec]: learning rate : 0.000250 loss : 0.247607 +[02:22:42.140] iteration 22100 [100.83 sec]: learning rate : 0.000250 loss : 0.310513 +[02:23:29.654] iteration 22200 [148.34 sec]: learning rate : 0.000250 loss : 0.357760 +[02:24:17.659] iteration 22300 [196.35 sec]: learning rate : 0.000250 loss : 1.538281 +[02:25:05.012] iteration 22400 [243.70 sec]: learning rate : 0.000250 loss : 0.258267 +[02:25:36.477] Epoch 38 Evaluation: +[02:29:51.627] average MSE: 0.06629204934649673 average PSNR: 24.809782644794865 average SSIM: 0.6222690526893299 +[02:30:08.862] iteration 22500 [17.21 sec]: learning rate : 0.000250 loss : 0.357262 +[02:30:56.225] iteration 22600 [64.57 sec]: learning rate : 0.000250 loss : 0.358894 +[02:31:44.123] iteration 22700 [112.47 sec]: learning rate : 0.000250 loss : 0.302297 +[02:32:31.627] iteration 22800 [159.98 sec]: learning rate : 0.000250 loss : 0.313178 +[02:33:19.479] 
iteration 22900 [207.83 sec]: learning rate : 0.000250 loss : 0.270112 +[02:34:07.066] iteration 23000 [255.42 sec]: learning rate : 0.000250 loss : 0.267751 +[02:34:26.106] Epoch 39 Evaluation: +[02:38:41.973] average MSE: 0.0608202171855354 average PSNR: 25.18860133011461 average SSIM: 0.6405425801591559 +[02:39:10.742] iteration 23100 [28.75 sec]: learning rate : 0.000250 loss : 0.289142 +[02:39:58.152] iteration 23200 [76.15 sec]: learning rate : 0.000250 loss : 0.277579 +[02:40:46.570] iteration 23300 [124.58 sec]: learning rate : 0.000250 loss : 0.419677 +[02:41:34.109] iteration 23400 [172.11 sec]: learning rate : 0.000250 loss : 0.385234 +[02:42:21.475] iteration 23500 [219.48 sec]: learning rate : 0.000250 loss : 0.394707 +[02:43:09.110] iteration 23600 [267.11 sec]: learning rate : 0.000250 loss : 0.337297 +[02:43:16.718] Epoch 40 Evaluation: +[02:47:34.668] average MSE: 0.05693904428703073 average PSNR: 25.48504024862695 average SSIM: 0.6488939643017126 +[02:48:15.130] iteration 23700 [40.44 sec]: learning rate : 0.000250 loss : 0.335183 +[02:49:02.487] iteration 23800 [87.80 sec]: learning rate : 0.000250 loss : 0.381867 +[02:49:50.449] iteration 23900 [135.76 sec]: learning rate : 0.000250 loss : 0.418163 +[02:50:37.947] iteration 24000 [183.26 sec]: learning rate : 0.000250 loss : 0.324235 +[02:51:25.564] iteration 24100 [230.87 sec]: learning rate : 0.000250 loss : 0.300341 +[02:52:09.766] Epoch 41 Evaluation: +[02:56:24.710] average MSE: 0.04880702297104693 average PSNR: 26.152304019474357 average SSIM: 0.6893403278938838 +[02:56:28.689] iteration 24200 [3.96 sec]: learning rate : 0.000250 loss : 0.485011 +[02:57:16.046] iteration 24300 [51.31 sec]: learning rate : 0.000250 loss : 0.285162 +[02:58:04.066] iteration 24400 [99.33 sec]: learning rate : 0.000250 loss : 0.352271 +[02:58:51.543] iteration 24500 [146.81 sec]: learning rate : 0.000250 loss : 0.314337 +[02:59:38.955] iteration 24600 [194.22 sec]: learning rate : 0.000250 loss : 0.316828 +[03:00:26.471] iteration 24700 [241.74 sec]: learning rate : 0.000250 loss : 0.322173 +[03:00:58.820] Epoch 42 Evaluation: +[03:05:15.629] average MSE: 0.05951658504933196 average PSNR: 25.283601957100245 average SSIM: 0.639879319329889 +[03:05:30.985] iteration 24800 [15.33 sec]: learning rate : 0.000250 loss : 0.426388 +[03:06:18.550] iteration 24900 [62.90 sec]: learning rate : 0.000250 loss : 0.374429 +[03:07:06.184] iteration 25000 [110.53 sec]: learning rate : 0.000250 loss : 0.371113 +[03:07:54.174] iteration 25100 [158.52 sec]: learning rate : 0.000250 loss : 0.287603 +[03:08:42.433] iteration 25200 [206.78 sec]: learning rate : 0.000250 loss : 0.445980 +[03:09:29.922] iteration 25300 [254.27 sec]: learning rate : 0.000250 loss : 0.368994 +[03:09:50.780] Epoch 43 Evaluation: +[03:14:14.395] average MSE: 0.08166925651965407 average PSNR: 23.895785133267648 average SSIM: 0.6490213497535168 +[03:14:41.155] iteration 25400 [26.74 sec]: learning rate : 0.000250 loss : 0.344556 +[03:15:28.656] iteration 25500 [74.24 sec]: learning rate : 0.000250 loss : 0.283585 +[03:16:16.102] iteration 25600 [121.68 sec]: learning rate : 0.000250 loss : 0.305686 +[03:17:03.481] iteration 25700 [169.06 sec]: learning rate : 0.000250 loss : 0.319275 +[03:17:50.980] iteration 25800 [216.56 sec]: learning rate : 0.000250 loss : 0.245870 +[03:18:38.963] iteration 25900 [264.54 sec]: learning rate : 0.000250 loss : 0.337372 +[03:18:48.528] Epoch 44 Evaluation: +[03:23:14.027] average MSE: 0.05590563570600505 average PSNR: 25.56091985784383 average 
SSIM: 0.6499425462695825 +[03:23:52.110] iteration 26000 [38.06 sec]: learning rate : 0.000250 loss : 0.383055 +[03:24:39.616] iteration 26100 [85.57 sec]: learning rate : 0.000250 loss : 0.244987 +[03:25:27.708] iteration 26200 [133.66 sec]: learning rate : 0.000250 loss : 0.319933 +[03:26:16.050] iteration 26300 [182.00 sec]: learning rate : 0.000250 loss : 0.278468 +[03:27:03.567] iteration 26400 [229.52 sec]: learning rate : 0.000250 loss : 0.259083 +[03:27:49.155] Epoch 45 Evaluation: +[03:32:07.715] average MSE: 0.054117050715113184 average PSNR: 25.708148947173513 average SSIM: 0.6562340397500518 +[03:32:09.800] iteration 26500 [2.06 sec]: learning rate : 0.000250 loss : 0.372330 +[03:32:57.337] iteration 26600 [49.60 sec]: learning rate : 0.000250 loss : 0.304797 +[03:33:45.182] iteration 26700 [97.45 sec]: learning rate : 0.000250 loss : 0.291486 +[03:34:32.598] iteration 26800 [144.86 sec]: learning rate : 0.000250 loss : 0.346403 +[03:35:20.147] iteration 26900 [192.41 sec]: learning rate : 0.000250 loss : 0.329226 +[03:36:07.543] iteration 27000 [239.80 sec]: learning rate : 0.000250 loss : 0.309370 +[03:36:42.276] Epoch 46 Evaluation: +[03:41:07.816] average MSE: 0.07098429180794606 average PSNR: 24.504908585826733 average SSIM: 0.6115226326719956 +[03:41:21.311] iteration 27100 [13.47 sec]: learning rate : 0.000250 loss : 0.406992 +[03:42:09.406] iteration 27200 [61.57 sec]: learning rate : 0.000250 loss : 0.312728 +[03:42:56.937] iteration 27300 [109.10 sec]: learning rate : 0.000250 loss : 0.279278 +[03:43:44.614] iteration 27400 [156.78 sec]: learning rate : 0.000250 loss : 0.346314 +[03:44:32.825] iteration 27500 [204.99 sec]: learning rate : 0.000250 loss : 0.331734 +[03:45:20.301] iteration 27600 [252.46 sec]: learning rate : 0.000250 loss : 0.313653 +[03:45:43.132] Epoch 47 Evaluation: +[03:49:59.153] average MSE: 0.05813297119656267 average PSNR: 25.389656589640197 average SSIM: 0.644074212563737 +[03:50:24.031] iteration 27700 [24.85 sec]: learning rate : 0.000250 loss : 0.427651 +[03:51:11.658] iteration 27800 [72.48 sec]: learning rate : 0.000250 loss : 0.350782 +[03:51:59.110] iteration 27900 [119.93 sec]: learning rate : 0.000250 loss : 0.500601 +[03:52:46.671] iteration 28000 [167.49 sec]: learning rate : 0.000250 loss : 0.312033 +[03:53:34.498] iteration 28100 [215.32 sec]: learning rate : 0.000250 loss : 0.325483 +[03:54:22.080] iteration 28200 [262.90 sec]: learning rate : 0.000250 loss : 0.308890 +[03:54:33.474] Epoch 48 Evaluation: +[03:58:59.773] average MSE: 0.048461643666123454 average PSNR: 26.191791524001275 average SSIM: 0.683283975446972 +[03:59:36.275] iteration 28300 [36.48 sec]: learning rate : 0.000250 loss : 0.316302 +[04:00:23.893] iteration 28400 [84.10 sec]: learning rate : 0.000250 loss : 0.305553 +[04:01:11.437] iteration 28500 [131.64 sec]: learning rate : 0.000250 loss : 0.330243 +[04:01:59.441] iteration 28600 [179.64 sec]: learning rate : 0.000250 loss : 0.288550 +[04:02:46.801] iteration 28700 [227.00 sec]: learning rate : 0.000250 loss : 0.352574 +[04:03:34.565] iteration 28800 [274.77 sec]: learning rate : 0.000250 loss : 0.314999 +[04:03:34.600] Epoch 49 Evaluation: +[04:07:54.386] average MSE: 0.04506105249184935 average PSNR: 26.509660407082507 average SSIM: 0.7080595494403713 +[04:08:42.408] iteration 28900 [48.00 sec]: learning rate : 0.000250 loss : 0.345224 +[04:09:29.760] iteration 29000 [95.35 sec]: learning rate : 0.000250 loss : 0.358849 +[04:10:17.386] iteration 29100 [142.98 sec]: learning rate : 0.000250 loss : 0.214744 
+[04:11:05.195] iteration 29200 [190.79 sec]: learning rate : 0.000250 loss : 0.294788 +[04:11:52.701] iteration 29300 [238.29 sec]: learning rate : 0.000250 loss : 0.329813 +[04:12:28.859] Epoch 50 Evaluation: +[04:16:46.360] average MSE: 0.05785445247895306 average PSNR: 25.414398167218543 average SSIM: 0.6455420917979633 +[04:16:57.923] iteration 29400 [11.54 sec]: learning rate : 0.000250 loss : 0.250310 +[04:17:45.309] iteration 29500 [58.93 sec]: learning rate : 0.000250 loss : 0.381236 +[04:18:33.290] iteration 29600 [106.91 sec]: learning rate : 0.000250 loss : 0.379517 +[04:19:20.936] iteration 29700 [154.55 sec]: learning rate : 0.000250 loss : 0.371370 +[04:20:08.461] iteration 29800 [202.08 sec]: learning rate : 0.000250 loss : 0.341354 +[04:20:56.489] iteration 29900 [250.11 sec]: learning rate : 0.000250 loss : 0.235264 +[04:21:21.545] Epoch 51 Evaluation: +[04:25:38.257] average MSE: 0.04778628465410208 average PSNR: 26.24411318081719 average SSIM: 0.6917432316239714 +[04:26:02.022] iteration 30000 [23.74 sec]: learning rate : 0.000250 loss : 0.320936 +[04:26:50.293] iteration 30100 [72.01 sec]: learning rate : 0.000250 loss : 0.358203 +[04:27:37.793] iteration 30200 [119.51 sec]: learning rate : 0.000250 loss : 0.336851 +[04:28:25.252] iteration 30300 [166.97 sec]: learning rate : 0.000250 loss : 0.332415 +[04:29:12.784] iteration 30400 [214.50 sec]: learning rate : 0.000250 loss : 0.376071 +[04:30:00.349] iteration 30500 [262.07 sec]: learning rate : 0.000250 loss : 0.338815 +[04:30:13.624] Epoch 52 Evaluation: +[04:34:30.575] average MSE: 0.0474882748894975 average PSNR: 26.28275212878474 average SSIM: 0.685192234541638 +[04:35:05.025] iteration 30600 [34.43 sec]: learning rate : 0.000250 loss : 0.313181 +[04:35:52.744] iteration 30700 [82.15 sec]: learning rate : 0.000250 loss : 0.264347 +[04:36:40.387] iteration 30800 [129.79 sec]: learning rate : 0.000250 loss : 0.356822 +[04:37:27.908] iteration 30900 [177.31 sec]: learning rate : 0.000250 loss : 0.306638 +[04:38:16.002] iteration 31000 [225.41 sec]: learning rate : 0.000250 loss : 0.242501 +[04:39:04.073] iteration 31100 [273.48 sec]: learning rate : 0.000250 loss : 0.995845 +[04:39:05.976] Epoch 53 Evaluation: +[04:43:20.924] average MSE: 0.08866399449436722 average PSNR: 23.524328501552326 average SSIM: 0.6021311117911008 +[04:44:07.233] iteration 31200 [46.28 sec]: learning rate : 0.000250 loss : 0.288043 +[04:44:55.162] iteration 31300 [94.21 sec]: learning rate : 0.000250 loss : 0.320354 +[04:45:42.529] iteration 31400 [141.58 sec]: learning rate : 0.000250 loss : 0.275535 +[04:46:30.014] iteration 31500 [189.06 sec]: learning rate : 0.000250 loss : 0.319112 +[04:47:17.475] iteration 31600 [236.52 sec]: learning rate : 0.000250 loss : 0.340211 +[04:47:55.400] Epoch 54 Evaluation: +[04:52:12.702] average MSE: 0.043549690057280335 average PSNR: 26.646155734354736 average SSIM: 0.7045564120838441 +[04:52:22.409] iteration 31700 [9.68 sec]: learning rate : 0.000250 loss : 0.347937 +[04:53:10.098] iteration 31800 [57.37 sec]: learning rate : 0.000250 loss : 0.339213 +[04:53:57.678] iteration 31900 [104.95 sec]: learning rate : 0.000250 loss : 0.366442 +[04:54:45.593] iteration 32000 [152.87 sec]: learning rate : 0.000250 loss : 0.318344 +[04:55:33.177] iteration 32100 [200.45 sec]: learning rate : 0.000250 loss : 0.325259 +[04:56:21.438] iteration 32200 [248.71 sec]: learning rate : 0.000250 loss : 0.343579 +[04:56:48.502] Epoch 55 Evaluation: +[05:01:13.589] average MSE: 0.06174552234691121 average PSNR: 
25.11986656609731 average SSIM: 0.6335919269968187 +[05:01:34.619] iteration 32300 [21.01 sec]: learning rate : 0.000250 loss : 0.365205 +[05:02:22.121] iteration 32400 [68.51 sec]: learning rate : 0.000250 loss : 0.381824 +[05:03:10.136] iteration 32500 [116.52 sec]: learning rate : 0.000250 loss : 0.298822 +[05:03:58.034] iteration 32600 [164.42 sec]: learning rate : 0.000250 loss : 0.394535 +[05:04:45.527] iteration 32700 [211.91 sec]: learning rate : 0.000250 loss : 0.387997 +[05:05:32.958] iteration 32800 [259.35 sec]: learning rate : 0.000250 loss : 0.332387 +[05:05:48.411] Epoch 56 Evaluation: +[05:10:03.458] average MSE: 0.07466106069880955 average PSNR: 24.28212491263445 average SSIM: 0.6038679358528719 +[05:10:35.856] iteration 32900 [32.37 sec]: learning rate : 0.000250 loss : 0.332766 +[05:11:24.310] iteration 33000 [80.85 sec]: learning rate : 0.000250 loss : 0.350587 +[05:12:11.743] iteration 33100 [128.26 sec]: learning rate : 0.000250 loss : 0.360450 +[05:12:59.204] iteration 33200 [175.72 sec]: learning rate : 0.000250 loss : 0.307951 +[05:13:46.799] iteration 33300 [223.32 sec]: learning rate : 0.000250 loss : 0.324721 +[05:14:34.177] iteration 33400 [270.70 sec]: learning rate : 0.000250 loss : 0.311590 +[05:14:37.986] Epoch 57 Evaluation: +[05:18:53.341] average MSE: 0.05154045890228058 average PSNR: 25.91789993950643 average SSIM: 0.665724542696521 +[05:19:37.264] iteration 33500 [43.90 sec]: learning rate : 0.000250 loss : 0.357261 +[05:20:24.711] iteration 33600 [91.35 sec]: learning rate : 0.000250 loss : 0.408265 +[05:21:12.595] iteration 33700 [139.23 sec]: learning rate : 0.000250 loss : 0.398267 +[05:22:00.048] iteration 33800 [186.68 sec]: learning rate : 0.000250 loss : 0.349246 +[05:22:47.742] iteration 33900 [234.38 sec]: learning rate : 0.000250 loss : 0.281349 +[05:23:27.704] Epoch 58 Evaluation: +[05:27:46.150] average MSE: 0.047224091648790406 average PSNR: 26.301490498474234 average SSIM: 0.6851836889576953 +[05:27:53.946] iteration 34000 [7.77 sec]: learning rate : 0.000250 loss : 0.306488 +[05:28:42.121] iteration 34100 [55.95 sec]: learning rate : 0.000250 loss : 0.373481 +[05:29:29.511] iteration 34200 [103.34 sec]: learning rate : 0.000250 loss : 0.362403 +[05:30:17.160] iteration 34300 [150.99 sec]: learning rate : 0.000250 loss : 0.318961 +[05:31:04.631] iteration 34400 [198.46 sec]: learning rate : 0.000250 loss : 0.269439 +[05:31:51.970] iteration 34500 [245.80 sec]: learning rate : 0.000250 loss : 0.332694 +[05:32:20.483] Epoch 59 Evaluation: +[05:36:38.410] average MSE: 0.07870567857907795 average PSNR: 24.04597783562476 average SSIM: 0.5947310910725617 +[05:36:57.634] iteration 34600 [19.20 sec]: learning rate : 0.000250 loss : 0.290414 +[05:37:45.365] iteration 34700 [66.93 sec]: learning rate : 0.000250 loss : 0.487752 +[05:38:32.885] iteration 34800 [114.45 sec]: learning rate : 0.000250 loss : 0.332021 +[05:39:21.139] iteration 34900 [162.71 sec]: learning rate : 0.000250 loss : 0.354505 +[05:40:08.693] iteration 35000 [210.26 sec]: learning rate : 0.000250 loss : 0.362793 +[05:40:57.010] iteration 35100 [258.58 sec]: learning rate : 0.000250 loss : 0.312488 +[05:41:14.120] Epoch 60 Evaluation: +[05:45:40.071] average MSE: 0.0718720506740687 average PSNR: 24.44970391136503 average SSIM: 0.6048845333253218 +[05:46:10.704] iteration 35200 [30.61 sec]: learning rate : 0.000250 loss : 0.361233 +[05:46:58.118] iteration 35300 [78.02 sec]: learning rate : 0.000250 loss : 0.385454 +[05:47:45.685] iteration 35400 [125.59 sec]: learning rate : 
0.000250 loss : 0.328580 +[05:48:33.721] iteration 35500 [173.63 sec]: learning rate : 0.000250 loss : 0.401260 +[05:49:21.158] iteration 35600 [221.06 sec]: learning rate : 0.000250 loss : 0.291801 +[05:50:08.652] iteration 35700 [268.56 sec]: learning rate : 0.000250 loss : 0.272833 +[05:50:14.338] Epoch 61 Evaluation: +[05:54:30.256] average MSE: 0.052191951441807015 average PSNR: 25.861430136656494 average SSIM: 0.6669338097933292 +[05:55:12.132] iteration 35800 [41.85 sec]: learning rate : 0.000250 loss : 0.345302 +[05:56:00.067] iteration 35900 [89.79 sec]: learning rate : 0.000250 loss : 0.298163 +[05:56:48.005] iteration 36000 [137.73 sec]: learning rate : 0.000250 loss : 0.353797 +[05:57:35.891] iteration 36100 [185.61 sec]: learning rate : 0.000250 loss : 0.435048 +[05:58:23.422] iteration 36200 [233.14 sec]: learning rate : 0.000250 loss : 0.314876 +[05:59:05.228] Epoch 62 Evaluation: +[06:03:29.283] average MSE: 0.05808440744132604 average PSNR: 25.394005784024152 average SSIM: 0.6414520079923256 +[06:03:35.145] iteration 36300 [5.84 sec]: learning rate : 0.000250 loss : 0.292895 +[06:04:22.509] iteration 36400 [53.20 sec]: learning rate : 0.000250 loss : 0.274288 +[06:05:10.095] iteration 36500 [100.79 sec]: learning rate : 0.000250 loss : 0.318340 +[06:05:57.647] iteration 36600 [148.34 sec]: learning rate : 0.000250 loss : 0.309478 +[06:06:45.138] iteration 36700 [195.83 sec]: learning rate : 0.000250 loss : 0.346767 +[06:07:33.129] iteration 36800 [243.82 sec]: learning rate : 0.000250 loss : 0.267905 +[06:08:03.553] Epoch 63 Evaluation: +[06:12:20.230] average MSE: 0.06742440793534692 average PSNR: 24.730692589554483 average SSIM: 0.6158672010053047 +[06:12:37.465] iteration 36900 [17.21 sec]: learning rate : 0.000250 loss : 0.399104 +[06:13:25.531] iteration 37000 [65.28 sec]: learning rate : 0.000250 loss : 0.344397 +[06:14:12.993] iteration 37100 [112.74 sec]: learning rate : 0.000250 loss : 0.278738 +[06:15:00.400] iteration 37200 [160.15 sec]: learning rate : 0.000250 loss : 0.270733 +[06:15:48.254] iteration 37300 [208.00 sec]: learning rate : 0.000250 loss : 0.262502 +[06:16:35.795] iteration 37400 [255.54 sec]: learning rate : 0.000250 loss : 0.269144 +[06:16:54.824] Epoch 64 Evaluation: +[06:21:12.832] average MSE: 0.06513628782168376 average PSNR: 24.889361439235977 average SSIM: 0.6227900842702007 +[06:21:41.959] iteration 37500 [29.11 sec]: learning rate : 0.000250 loss : 0.291578 +[06:22:29.756] iteration 37600 [76.90 sec]: learning rate : 0.000250 loss : 0.257821 +[06:23:17.347] iteration 37700 [124.49 sec]: learning rate : 0.000250 loss : 0.405119 +[06:24:05.396] iteration 37800 [172.54 sec]: learning rate : 0.000250 loss : 0.372282 +[06:24:52.879] iteration 37900 [220.02 sec]: learning rate : 0.000250 loss : 0.388035 +[06:25:40.365] iteration 38000 [267.51 sec]: learning rate : 0.000250 loss : 0.353214 +[06:25:48.625] Epoch 65 Evaluation: +[06:30:10.671] average MSE: 0.08070252778960595 average PSNR: 23.936727894378706 average SSIM: 0.5965869198102938 +[06:30:50.816] iteration 38100 [40.12 sec]: learning rate : 0.000250 loss : 0.367582 +[06:31:38.537] iteration 38200 [87.84 sec]: learning rate : 0.000250 loss : 0.369272 +[06:32:26.073] iteration 38300 [135.38 sec]: learning rate : 0.000250 loss : 0.343400 +[06:33:13.807] iteration 38400 [183.14 sec]: learning rate : 0.000250 loss : 0.247126 +[06:34:02.208] iteration 38500 [231.51 sec]: learning rate : 0.000250 loss : 0.301475 +[06:34:45.804] Epoch 66 Evaluation: +[06:39:00.572] average MSE: 
0.07721247352094282 average PSNR: 24.131531106307875 average SSIM: 0.5998245761659495 +[06:39:04.567] iteration 38600 [3.97 sec]: learning rate : 0.000250 loss : 0.380138 +[06:39:52.775] iteration 38700 [52.19 sec]: learning rate : 0.000250 loss : 0.283052 +[06:40:40.288] iteration 38800 [99.69 sec]: learning rate : 0.000250 loss : 0.350242 +[06:41:28.931] iteration 38900 [148.34 sec]: learning rate : 0.000250 loss : 0.353741 +[06:42:16.415] iteration 39000 [195.82 sec]: learning rate : 0.000250 loss : 0.319133 +[06:43:03.789] iteration 39100 [243.19 sec]: learning rate : 0.000250 loss : 0.333836 +[06:43:36.129] Epoch 67 Evaluation: +[06:47:53.661] average MSE: 0.06510050255677274 average PSNR: 24.88396936022969 average SSIM: 0.6316416027503681 +[06:48:09.046] iteration 39200 [15.36 sec]: learning rate : 0.000250 loss : 0.363969 +[06:48:57.329] iteration 39300 [63.65 sec]: learning rate : 0.000250 loss : 0.424155 +[06:49:44.897] iteration 39400 [111.21 sec]: learning rate : 0.000250 loss : 0.376855 +[06:50:32.556] iteration 39500 [158.87 sec]: learning rate : 0.000250 loss : 0.270399 +[06:51:20.207] iteration 39600 [206.52 sec]: learning rate : 0.000250 loss : 0.302796 +[06:52:08.584] iteration 39700 [254.90 sec]: learning rate : 0.000250 loss : 0.295862 +[06:52:29.630] Epoch 68 Evaluation: +[06:56:51.592] average MSE: 0.062187608879119706 average PSNR: 25.091818182013782 average SSIM: 0.62795509208811 +[06:57:18.309] iteration 39800 [26.69 sec]: learning rate : 0.000250 loss : 0.301345 +[06:58:07.030] iteration 39900 [75.42 sec]: learning rate : 0.000250 loss : 0.310914 +[06:58:54.567] iteration 40000 [122.95 sec]: learning rate : 0.000063 loss : 0.339259 +[06:58:54.729] save model to model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/iter_40000.pth +[06:59:42.248] iteration 40100 [170.63 sec]: learning rate : 0.000125 loss : 0.340007 +[07:00:29.797] iteration 40200 [218.18 sec]: learning rate : 0.000125 loss : 0.270873 +[07:01:17.365] iteration 40300 [265.75 sec]: learning rate : 0.000125 loss : 0.312253 +[07:01:26.885] Epoch 69 Evaluation: +[07:05:56.556] average MSE: 0.06413115129361176 average PSNR: 24.95120741857692 average SSIM: 0.6255464384002399 +[07:06:34.836] iteration 40400 [38.26 sec]: learning rate : 0.000125 loss : 0.390786 +[07:07:22.250] iteration 40500 [85.67 sec]: learning rate : 0.000125 loss : 0.291553 +[07:08:09.911] iteration 40600 [133.33 sec]: learning rate : 0.000125 loss : 0.331276 +[07:08:58.162] iteration 40700 [181.58 sec]: learning rate : 0.000125 loss : 0.342048 +[07:09:45.972] iteration 40800 [229.39 sec]: learning rate : 0.000125 loss : 0.289746 +[07:10:31.573] Epoch 70 Evaluation: +[07:14:47.370] average MSE: 0.06942070010491243 average PSNR: 24.60208973087015 average SSIM: 0.6139522706327841 +[07:14:49.452] iteration 40900 [2.06 sec]: learning rate : 0.000125 loss : 0.384963 +[07:15:37.302] iteration 41000 [49.93 sec]: learning rate : 0.000125 loss : 0.298548 +[07:16:25.412] iteration 41100 [98.02 sec]: learning rate : 0.000125 loss : 0.260838 +[07:17:13.002] iteration 41200 [145.61 sec]: learning rate : 0.000125 loss : 0.371027 +[07:18:00.561] iteration 41300 [193.17 sec]: learning rate : 0.000125 loss : 0.389461 +[07:18:48.451] iteration 41400 [241.06 sec]: learning rate : 0.000125 loss : 0.334064 +[07:19:22.771] Epoch 71 Evaluation: +[07:23:39.861] average MSE: 0.0660282726323414 average PSNR: 24.82363773943539 average SSIM: 0.619079195249868 +[07:23:53.403] iteration 41500 [13.52 sec]: learning rate : 0.000125 loss : 0.321860 +[07:24:40.865] iteration 
41600 [60.98 sec]: learning rate : 0.000125 loss : 0.327676 +[07:25:29.003] iteration 41700 [109.12 sec]: learning rate : 0.000125 loss : 0.343285 +[07:26:16.982] iteration 41800 [157.10 sec]: learning rate : 0.000125 loss : 0.301849 +[07:27:04.513] iteration 41900 [204.63 sec]: learning rate : 0.000125 loss : 0.345020 +[07:27:52.685] iteration 42000 [252.80 sec]: learning rate : 0.000125 loss : 0.340926 +[07:28:15.505] Epoch 72 Evaluation: +[07:32:45.601] average MSE: 0.0790585924286511 average PSNR: 24.025814065209662 average SSIM: 0.5989154784869691 +[07:33:10.692] iteration 42100 [25.07 sec]: learning rate : 0.000125 loss : 0.255015 +[07:33:59.030] iteration 42200 [73.41 sec]: learning rate : 0.000125 loss : 0.319256 +[07:34:46.573] iteration 42300 [120.95 sec]: learning rate : 0.000125 loss : 0.377616 +[07:35:34.090] iteration 42400 [168.47 sec]: learning rate : 0.000125 loss : 0.286219 +[07:36:21.493] iteration 42500 [215.87 sec]: learning rate : 0.000125 loss : 0.344384 +[07:37:09.286] iteration 42600 [263.66 sec]: learning rate : 0.000125 loss : 0.365480 +[07:37:20.659] Epoch 73 Evaluation: +[07:41:35.483] average MSE: 0.07165058388313143 average PSNR: 24.460224142122243 average SSIM: 0.609400312102646 +[07:42:11.815] iteration 42700 [36.31 sec]: learning rate : 0.000125 loss : 0.263945 +[07:42:59.695] iteration 42800 [84.19 sec]: learning rate : 0.000125 loss : 0.319310 +[07:43:47.352] iteration 42900 [131.85 sec]: learning rate : 0.000125 loss : 0.340889 +[07:44:34.808] iteration 43000 [179.30 sec]: learning rate : 0.000125 loss : 0.270796 +[07:45:22.265] iteration 43100 [226.76 sec]: learning rate : 0.000125 loss : 0.350575 +[07:46:10.044] iteration 43200 [274.54 sec]: learning rate : 0.000125 loss : 0.295344 +[07:46:10.079] Epoch 74 Evaluation: +[07:50:25.581] average MSE: 0.07440187004198683 average PSNR: 24.293118713548374 average SSIM: 0.604319072292422 +[07:51:13.165] iteration 43300 [47.56 sec]: learning rate : 0.000125 loss : 0.339807 +[07:52:01.058] iteration 43400 [95.45 sec]: learning rate : 0.000125 loss : 0.386039 +[07:52:48.534] iteration 43500 [142.93 sec]: learning rate : 0.000125 loss : 0.206760 +[07:53:36.477] iteration 43600 [190.87 sec]: learning rate : 0.000125 loss : 0.288029 +[07:54:23.916] iteration 43700 [238.31 sec]: learning rate : 0.000125 loss : 0.300101 +[07:54:59.936] Epoch 75 Evaluation: +[07:59:17.582] average MSE: 0.0739904051873609 average PSNR: 24.318925794207676 average SSIM: 0.6063148680836362 +[07:59:29.360] iteration 43800 [11.76 sec]: learning rate : 0.000125 loss : 0.268623 +[08:00:16.901] iteration 43900 [59.30 sec]: learning rate : 0.000125 loss : 0.353424 +[08:01:04.407] iteration 44000 [106.80 sec]: learning rate : 0.000125 loss : 0.388591 +[08:01:51.836] iteration 44100 [154.23 sec]: learning rate : 0.000125 loss : 0.371435 +[08:02:39.394] iteration 44200 [201.79 sec]: learning rate : 0.000125 loss : 0.355219 +[08:03:27.359] iteration 44300 [249.76 sec]: learning rate : 0.000125 loss : 0.241621 +[08:03:52.158] Epoch 76 Evaluation: +[08:08:17.307] average MSE: 0.07020726158097623 average PSNR: 24.552396220740974 average SSIM: 0.6165381695522846 +[08:08:40.280] iteration 44400 [22.95 sec]: learning rate : 0.000125 loss : 0.319759 +[08:09:27.908] iteration 44500 [70.58 sec]: learning rate : 0.000125 loss : 0.324212 +[08:10:16.964] iteration 44600 [119.64 sec]: learning rate : 0.000125 loss : 0.329064 +[08:11:05.041] iteration 44700 [167.71 sec]: learning rate : 0.000125 loss : 0.295910 +[08:11:52.591] iteration 44800 [215.26 sec]: 
learning rate : 0.000125 loss : 0.385841 +[08:12:40.179] iteration 44900 [262.85 sec]: learning rate : 0.000125 loss : 0.361076 +[08:12:53.608] Epoch 77 Evaluation: +[08:17:09.927] average MSE: 0.07636314051123352 average PSNR: 24.176774841169458 average SSIM: 0.6014916713965103 +[08:17:44.354] iteration 45000 [34.40 sec]: learning rate : 0.000125 loss : 0.339866 +[08:18:32.567] iteration 45100 [82.62 sec]: learning rate : 0.000125 loss : 0.239378 +[08:19:20.085] iteration 45200 [130.14 sec]: learning rate : 0.000125 loss : 0.369971 +[08:20:07.721] iteration 45300 [177.77 sec]: learning rate : 0.000125 loss : 0.329738 +[08:20:55.363] iteration 45400 [225.41 sec]: learning rate : 0.000125 loss : 0.242381 +[08:21:43.437] iteration 45500 [273.49 sec]: learning rate : 0.000125 loss : 0.325474 +[08:21:45.348] Epoch 78 Evaluation: +[08:26:14.383] average MSE: 0.07067349881510991 average PSNR: 24.520207983321068 average SSIM: 0.6108763387638424 +[08:27:00.880] iteration 45600 [46.47 sec]: learning rate : 0.000125 loss : 0.313349 +[08:27:48.358] iteration 45700 [93.95 sec]: learning rate : 0.000125 loss : 0.340212 +[08:28:36.230] iteration 45800 [141.83 sec]: learning rate : 0.000125 loss : 0.291836 +[08:29:23.773] iteration 45900 [189.37 sec]: learning rate : 0.000125 loss : 0.364162 +[08:30:11.186] iteration 46000 [236.78 sec]: learning rate : 0.000125 loss : 0.339022 +[08:30:49.189] Epoch 79 Evaluation: +[08:35:06.349] average MSE: 0.07769544096683453 average PSNR: 24.100285494279298 average SSIM: 0.5990072029756212 +[08:35:16.122] iteration 46100 [9.75 sec]: learning rate : 0.000125 loss : 0.306022 +[08:36:03.693] iteration 46200 [57.32 sec]: learning rate : 0.000125 loss : 0.395575 +[08:36:51.152] iteration 46300 [104.78 sec]: learning rate : 0.000125 loss : 0.282960 +[08:37:38.690] iteration 46400 [152.32 sec]: learning rate : 0.000125 loss : 0.317512 +[08:38:26.584] iteration 46500 [200.21 sec]: learning rate : 0.000125 loss : 0.317133 +[08:39:13.970] iteration 46600 [247.60 sec]: learning rate : 0.000125 loss : 0.309374 +[08:39:40.602] Epoch 80 Evaluation: +[08:43:58.173] average MSE: 0.08738629685066361 average PSNR: 23.58607703988281 average SSIM: 0.5961553204066047 +[08:44:19.229] iteration 46700 [21.03 sec]: learning rate : 0.000125 loss : 0.309051 +[08:45:06.849] iteration 46800 [68.65 sec]: learning rate : 0.000125 loss : 0.433625 +[08:45:54.315] iteration 46900 [116.12 sec]: learning rate : 0.000125 loss : 0.325229 +[08:46:42.446] iteration 47000 [164.25 sec]: learning rate : 0.000125 loss : 0.353794 +[08:47:30.046] iteration 47100 [211.85 sec]: learning rate : 0.000125 loss : 0.328949 +[08:48:18.209] iteration 47200 [260.01 sec]: learning rate : 0.000125 loss : 0.835133 +[08:48:33.411] Epoch 81 Evaluation: +[08:52:57.475] average MSE: 0.0752895187245366 average PSNR: 24.240194559002052 average SSIM: 0.6017546160604677 +[08:53:30.014] iteration 47300 [32.52 sec]: learning rate : 0.000125 loss : 0.325751 +[08:54:17.412] iteration 47400 [79.91 sec]: learning rate : 0.000125 loss : 0.266491 +[08:55:05.223] iteration 47500 [127.73 sec]: learning rate : 0.000125 loss : 0.320364 +[08:55:53.283] iteration 47600 [175.79 sec]: learning rate : 0.000125 loss : 0.284074 +[08:56:40.796] iteration 47700 [223.30 sec]: learning rate : 0.000125 loss : 0.341040 +[08:57:28.425] iteration 47800 [270.93 sec]: learning rate : 0.000125 loss : 0.304525 +[08:57:32.229] Epoch 82 Evaluation: +[09:01:51.233] average MSE: 0.0831004544713009 average PSNR: 23.8061683686311 average SSIM: 0.5937563085645914 
+[09:02:35.222] iteration 47900 [43.97 sec]: learning rate : 0.000125 loss : 0.334100 +[09:03:23.105] iteration 48000 [91.85 sec]: learning rate : 0.000125 loss : 0.414731 +[09:04:10.686] iteration 48100 [139.43 sec]: learning rate : 0.000125 loss : 0.371946 +[09:04:58.630] iteration 48200 [187.37 sec]: learning rate : 0.000125 loss : 0.383250 +[09:05:46.125] iteration 48300 [234.87 sec]: learning rate : 0.000125 loss : 0.277362 +[09:06:26.558] Epoch 83 Evaluation: +[09:10:43.251] average MSE: 0.08137460085730035 average PSNR: 23.898179965269527 average SSIM: 0.5965794374695196 +[09:10:51.094] iteration 48400 [7.82 sec]: learning rate : 0.000125 loss : 0.334827 +[09:11:39.009] iteration 48500 [55.74 sec]: learning rate : 0.000125 loss : 0.370452 +[09:12:26.704] iteration 48600 [103.43 sec]: learning rate : 0.000125 loss : 0.344519 +[09:13:14.353] iteration 48700 [151.10 sec]: learning rate : 0.000125 loss : 0.329450 +[09:14:01.774] iteration 48800 [198.50 sec]: learning rate : 0.000125 loss : 0.315944 +[09:14:49.471] iteration 48900 [246.20 sec]: learning rate : 0.000125 loss : 0.304532 +[09:15:17.902] Epoch 84 Evaluation: +[09:19:38.137] average MSE: 0.08367518703017722 average PSNR: 23.776908617943935 average SSIM: 0.5936484813864454 +[09:19:57.411] iteration 49000 [19.25 sec]: learning rate : 0.000125 loss : 0.300535 +[09:20:44.792] iteration 49100 [66.63 sec]: learning rate : 0.000125 loss : 1.183732 +[09:21:32.394] iteration 49200 [114.23 sec]: learning rate : 0.000125 loss : 0.383716 +[09:22:19.880] iteration 49300 [161.72 sec]: learning rate : 0.000125 loss : 0.377831 +[09:23:08.096] iteration 49400 [209.94 sec]: learning rate : 0.000125 loss : 0.304319 +[09:23:55.580] iteration 49500 [257.42 sec]: learning rate : 0.000125 loss : 0.269392 +[09:24:12.637] Epoch 85 Evaluation: +[09:28:36.908] average MSE: 0.0954826836270453 average PSNR: 23.204509082915404 average SSIM: 0.5949354203797416 +[09:29:07.968] iteration 49600 [31.04 sec]: learning rate : 0.000125 loss : 0.421092 +[09:29:55.635] iteration 49700 [78.71 sec]: learning rate : 0.000125 loss : 0.381260 +[09:30:43.178] iteration 49800 [126.25 sec]: learning rate : 0.000125 loss : 0.265481 +[09:31:30.559] iteration 49900 [173.63 sec]: learning rate : 0.000125 loss : 0.335861 +[09:32:18.163] iteration 50000 [221.23 sec]: learning rate : 0.000125 loss : 0.284460 +[09:33:05.810] iteration 50100 [268.88 sec]: learning rate : 0.000125 loss : 0.257447 +[09:33:11.519] Epoch 86 Evaluation: +[09:37:32.927] average MSE: 0.08343649310753878 average PSNR: 23.787926321233968 average SSIM: 0.5946214068296477 +[09:38:14.814] iteration 50200 [41.86 sec]: learning rate : 0.000125 loss : 0.306642 +[09:39:02.369] iteration 50300 [89.42 sec]: learning rate : 0.000125 loss : 0.292300 +[09:39:50.071] iteration 50400 [137.12 sec]: learning rate : 0.000125 loss : 0.325991 +[09:40:37.409] iteration 50500 [184.46 sec]: learning rate : 0.000125 loss : 0.418764 +[09:41:25.987] iteration 50600 [233.04 sec]: learning rate : 0.000125 loss : 0.309058 +[09:42:07.819] Epoch 87 Evaluation: +[09:46:30.542] average MSE: 0.07281371423409455 average PSNR: 24.39001410924535 average SSIM: 0.6075491669504032 +[09:46:36.592] iteration 50700 [6.03 sec]: learning rate : 0.000125 loss : 0.313022 +[09:47:24.358] iteration 50800 [53.79 sec]: learning rate : 0.000125 loss : 0.297689 +[09:48:11.831] iteration 50900 [101.27 sec]: learning rate : 0.000125 loss : 0.391164 +[09:48:59.741] iteration 51000 [149.18 sec]: learning rate : 0.000125 loss : 0.307399 +[09:49:47.264] iteration 
51100 [196.70 sec]: learning rate : 0.000125 loss : 0.379967 +[09:50:34.748] iteration 51200 [244.18 sec]: learning rate : 0.000125 loss : 0.245404 +[09:51:05.522] Epoch 88 Evaluation: +[09:55:20.381] average MSE: 0.07683831404445064 average PSNR: 24.15099422916315 average SSIM: 0.5972026221232537 +[09:55:37.698] iteration 51300 [17.30 sec]: learning rate : 0.000125 loss : 0.349884 +[09:56:26.215] iteration 51400 [65.81 sec]: learning rate : 0.000125 loss : 0.343256 +[09:57:13.747] iteration 51500 [113.34 sec]: learning rate : 0.000125 loss : 0.283745 +[09:58:01.157] iteration 51600 [160.75 sec]: learning rate : 0.000125 loss : 0.292774 +[09:58:48.735] iteration 51700 [208.33 sec]: learning rate : 0.000125 loss : 0.262744 +[09:59:36.747] iteration 51800 [256.34 sec]: learning rate : 0.000125 loss : 0.266995 +[09:59:55.771] Epoch 89 Evaluation: +[10:04:18.252] average MSE: 0.07461087525586216 average PSNR: 24.281411052320774 average SSIM: 0.5995854137947383 +[10:04:46.860] iteration 51900 [28.59 sec]: learning rate : 0.000125 loss : 0.350208 +[10:05:34.838] iteration 52000 [76.56 sec]: learning rate : 0.000125 loss : 0.280271 +[10:06:22.232] iteration 52100 [123.96 sec]: learning rate : 0.000125 loss : 0.366277 +[10:07:09.711] iteration 52200 [171.44 sec]: learning rate : 0.000125 loss : 0.439928 +[10:07:57.742] iteration 52300 [219.47 sec]: learning rate : 0.000125 loss : 0.298440 +[10:08:45.292] iteration 52400 [267.02 sec]: learning rate : 0.000125 loss : 0.331249 +[10:08:52.901] Epoch 90 Evaluation: +[10:13:19.550] average MSE: 0.08101266042631099 average PSNR: 23.919956573050264 average SSIM: 0.5908394983218115 +[10:13:59.698] iteration 52500 [40.12 sec]: learning rate : 0.000125 loss : 0.349183 +[10:14:47.190] iteration 52600 [87.62 sec]: learning rate : 0.000125 loss : 0.336831 +[10:15:34.723] iteration 52700 [135.15 sec]: learning rate : 0.000125 loss : 0.351824 +[10:16:22.313] iteration 52800 [182.74 sec]: learning rate : 0.000125 loss : 0.268585 +[10:17:10.172] iteration 52900 [230.62 sec]: learning rate : 0.000125 loss : 0.318247 +[10:17:54.034] Epoch 91 Evaluation: +[10:22:10.337] average MSE: 0.08662817210242758 average PSNR: 23.623512573990784 average SSIM: 0.5996092335940097 +[10:22:14.321] iteration 53000 [3.96 sec]: learning rate : 0.000125 loss : 0.349741 +[10:23:01.868] iteration 53100 [51.51 sec]: learning rate : 0.000125 loss : 0.285670 +[10:23:49.546] iteration 53200 [99.19 sec]: learning rate : 0.000125 loss : 0.328090 +[10:24:37.512] iteration 53300 [147.15 sec]: learning rate : 0.000125 loss : 0.339202 +[10:25:25.009] iteration 53400 [194.65 sec]: learning rate : 0.000125 loss : 0.377030 +[10:26:12.800] iteration 53500 [242.44 sec]: learning rate : 0.000125 loss : 0.340895 +[10:26:45.175] Epoch 92 Evaluation: +[10:31:01.256] average MSE: 0.08829979498062387 average PSNR: 23.53982600433472 average SSIM: 0.5953653178291078 +[10:31:16.593] iteration 53600 [15.31 sec]: learning rate : 0.000125 loss : 0.340733 +[10:32:04.190] iteration 53700 [62.91 sec]: learning rate : 0.000125 loss : 0.384596 +[10:32:51.567] iteration 53800 [110.29 sec]: learning rate : 0.000125 loss : 0.395905 +[10:33:39.412] iteration 53900 [158.13 sec]: learning rate : 0.000125 loss : 0.309149 +[10:34:27.014] iteration 54000 [205.73 sec]: learning rate : 0.000125 loss : 0.319747 +[10:35:14.519] iteration 54100 [253.24 sec]: learning rate : 0.000125 loss : 0.345495 +[10:35:36.463] Epoch 93 Evaluation: +[10:39:58.267] average MSE: 0.07188465642659932 average PSNR: 24.446615825057528 average SSIM: 
0.6128023689054171 +[10:40:25.196] iteration 54200 [26.91 sec]: learning rate : 0.000125 loss : 0.332920 +[10:41:13.657] iteration 54300 [75.37 sec]: learning rate : 0.000125 loss : 0.278371 +[10:42:01.510] iteration 54400 [123.22 sec]: learning rate : 0.000125 loss : 0.372024 +[10:42:49.009] iteration 54500 [170.72 sec]: learning rate : 0.000125 loss : 0.267475 +[10:43:36.415] iteration 54600 [218.12 sec]: learning rate : 0.000125 loss : 0.261975 +[10:44:23.942] iteration 54700 [265.66 sec]: learning rate : 0.000125 loss : 0.324894 +[10:44:33.442] Epoch 94 Evaluation: +[10:48:53.528] average MSE: 0.08962230410037393 average PSNR: 23.47775684664411 average SSIM: 0.5912391094574524 +[10:49:31.779] iteration 54800 [38.23 sec]: learning rate : 0.000125 loss : 0.400610 +[10:50:19.272] iteration 54900 [85.72 sec]: learning rate : 0.000125 loss : 0.246142 +[10:51:06.874] iteration 55000 [133.32 sec]: learning rate : 0.000125 loss : 0.319254 +[10:51:54.510] iteration 55100 [180.96 sec]: learning rate : 0.000125 loss : 0.259914 +[10:52:42.388] iteration 55200 [228.84 sec]: learning rate : 0.000125 loss : 0.266727 +[10:53:28.242] Epoch 95 Evaluation: +[10:57:45.042] average MSE: 0.08325932522010171 average PSNR: 23.798519161263545 average SSIM: 0.5933919242767993 +[10:57:47.116] iteration 55300 [2.05 sec]: learning rate : 0.000125 loss : 0.373670 +[10:58:34.482] iteration 55400 [49.42 sec]: learning rate : 0.000125 loss : 0.297131 +[10:59:22.047] iteration 55500 [96.98 sec]: learning rate : 0.000125 loss : 0.263704 +[11:00:10.244] iteration 55600 [145.18 sec]: learning rate : 0.000125 loss : 0.349501 +[11:00:57.769] iteration 55700 [192.70 sec]: learning rate : 0.000125 loss : 0.359135 +[11:01:45.365] iteration 55800 [240.30 sec]: learning rate : 0.000125 loss : 0.328445 +[11:02:19.600] Epoch 96 Evaluation: +[11:06:42.100] average MSE: 0.09036770734402523 average PSNR: 23.44126482207879 average SSIM: 0.594293297173246 +[11:06:55.699] iteration 55900 [13.58 sec]: learning rate : 0.000125 loss : 0.346278 +[11:07:43.171] iteration 56000 [61.05 sec]: learning rate : 0.000125 loss : 0.312571 +[11:08:30.624] iteration 56100 [108.50 sec]: learning rate : 0.000125 loss : 0.300095 +[11:09:18.449] iteration 56200 [156.33 sec]: learning rate : 0.000125 loss : 0.311388 +[11:10:05.803] iteration 56300 [203.68 sec]: learning rate : 0.000125 loss : 0.339320 +[11:10:53.593] iteration 56400 [251.47 sec]: learning rate : 0.000125 loss : 0.266158 +[11:11:16.327] Epoch 97 Evaluation: +[11:15:42.365] average MSE: 0.0882042444462436 average PSNR: 23.545821668808685 average SSIM: 0.593278173870621 +[11:16:07.270] iteration 56500 [24.88 sec]: learning rate : 0.000125 loss : 0.297876 +[11:16:54.973] iteration 56600 [72.58 sec]: learning rate : 0.000125 loss : 0.311174 +[11:17:43.011] iteration 56700 [120.62 sec]: learning rate : 0.000125 loss : 0.363222 +[11:18:30.880] iteration 56800 [168.49 sec]: learning rate : 0.000125 loss : 0.306627 +[11:19:18.364] iteration 56900 [215.97 sec]: learning rate : 0.000125 loss : 0.330476 +[11:20:06.389] iteration 57000 [264.00 sec]: learning rate : 0.000125 loss : 0.302119 +[11:20:17.808] Epoch 98 Evaluation: +[11:24:33.229] average MSE: 0.08000316006655726 average PSNR: 23.973543939259393 average SSIM: 0.5973941513190921 +[11:25:09.585] iteration 57100 [36.33 sec]: learning rate : 0.000125 loss : 0.318548 +[11:25:58.101] iteration 57200 [84.85 sec]: learning rate : 0.000125 loss : 0.338128 +[11:26:45.495] iteration 57300 [132.24 sec]: learning rate : 0.000125 loss : 0.314258 
+[11:27:32.995] iteration 57400 [179.74 sec]: learning rate : 0.000125 loss : 0.292131 +[11:28:20.596] iteration 57500 [227.34 sec]: learning rate : 0.000125 loss : 0.399974 +[11:29:08.093] iteration 57600 [274.84 sec]: learning rate : 0.000125 loss : 0.302422 +[11:29:08.138] Epoch 99 Evaluation: +[11:33:25.294] average MSE: 0.08960538328571613 average PSNR: 23.47634184845135 average SSIM: 0.5912439362918286 +[11:34:13.473] iteration 57700 [48.16 sec]: learning rate : 0.000125 loss : 0.335801 +[11:35:00.977] iteration 57800 [95.66 sec]: learning rate : 0.000125 loss : 0.312464 +[11:35:48.726] iteration 57900 [143.41 sec]: learning rate : 0.000125 loss : 0.235231 +[11:36:36.218] iteration 58000 [190.90 sec]: learning rate : 0.000125 loss : 0.300805 +[11:37:23.982] iteration 58100 [238.67 sec]: learning rate : 0.000125 loss : 0.303324 +[11:38:00.097] Epoch 100 Evaluation: +[11:42:21.832] average MSE: 0.08489559810309136 average PSNR: 23.71312353910849 average SSIM: 0.5921138116541018 +[11:42:33.398] iteration 58200 [11.54 sec]: learning rate : 0.000125 loss : 0.253442 +[11:43:21.045] iteration 58300 [59.19 sec]: learning rate : 0.000125 loss : 0.400442 +[11:44:08.518] iteration 58400 [106.66 sec]: learning rate : 0.000125 loss : 0.402006 +[11:44:55.995] iteration 58500 [154.14 sec]: learning rate : 0.000125 loss : 0.371785 +[11:45:43.498] iteration 58600 [201.64 sec]: learning rate : 0.000125 loss : 0.416298 +[11:46:30.857] iteration 58700 [249.00 sec]: learning rate : 0.000125 loss : 0.223080 +[11:46:55.583] Epoch 101 Evaluation: +[11:51:20.293] average MSE: 0.07394751171271485 average PSNR: 24.321054314763376 average SSIM: 0.6085459291305055 +[11:51:43.205] iteration 58800 [22.89 sec]: learning rate : 0.000125 loss : 0.339097 +[11:52:30.926] iteration 58900 [70.61 sec]: learning rate : 0.000125 loss : 0.369463 +[11:53:18.435] iteration 59000 [118.12 sec]: learning rate : 0.000125 loss : 0.349679 +[11:54:06.825] iteration 59100 [166.51 sec]: learning rate : 0.000125 loss : 0.331333 +[11:54:54.337] iteration 59200 [214.02 sec]: learning rate : 0.000125 loss : 0.353474 +[11:55:41.827] iteration 59300 [261.51 sec]: learning rate : 0.000125 loss : 0.348328 +[11:55:55.435] Epoch 102 Evaluation: +[12:00:12.448] average MSE: 0.0711032659304082 average PSNR: 24.491205255758786 average SSIM: 0.6168202651027024 +[12:00:46.866] iteration 59400 [34.40 sec]: learning rate : 0.000125 loss : 0.337879 +[12:01:34.461] iteration 59500 [81.99 sec]: learning rate : 0.000125 loss : 0.276778 +[12:02:22.276] iteration 59600 [129.81 sec]: learning rate : 0.000125 loss : 0.396494 +[12:03:09.883] iteration 59700 [177.41 sec]: learning rate : 0.000125 loss : 0.323576 +[12:03:57.740] iteration 59800 [225.27 sec]: learning rate : 0.000125 loss : 0.253261 +[12:04:45.646] iteration 59900 [273.17 sec]: learning rate : 0.000125 loss : 0.331106 +[12:04:47.547] Epoch 103 Evaluation: +[12:09:14.743] average MSE: 0.07029699206523851 average PSNR: 24.544327935506573 average SSIM: 0.6175608333609659 +[12:10:00.716] iteration 60000 [45.95 sec]: learning rate : 0.000031 loss : 0.323085 +[12:10:00.881] save model to model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/iter_60000.pth +[12:10:49.019] iteration 60100 [94.25 sec]: learning rate : 0.000063 loss : 0.334875 +[12:11:37.001] iteration 60200 [142.23 sec]: learning rate : 0.000063 loss : 0.315821 +[12:12:24.849] iteration 60300 [190.08 sec]: learning rate : 0.000063 loss : 0.339333 +[12:13:12.424] iteration 60400 [237.66 sec]: learning rate : 0.000063 loss : 0.371457 
+[12:13:50.672] Epoch 104 Evaluation: +[12:18:20.860] average MSE: 0.0729800885899015 average PSNR: 24.376607820269392 average SSIM: 0.6124787765380044 +[12:18:30.910] iteration 60500 [10.04 sec]: learning rate : 0.000063 loss : 0.388918 +[12:19:18.292] iteration 60600 [57.41 sec]: learning rate : 0.000063 loss : 0.370603 +[12:20:05.794] iteration 60700 [104.91 sec]: learning rate : 0.000063 loss : 0.303919 +[12:20:53.230] iteration 60800 [152.35 sec]: learning rate : 0.000063 loss : 0.322818 +[12:21:40.752] iteration 60900 [199.87 sec]: learning rate : 0.000063 loss : 0.303971 +[12:22:28.252] iteration 61000 [247.37 sec]: learning rate : 0.000063 loss : 0.341269 +[12:22:54.794] Epoch 105 Evaluation: +[12:27:16.696] average MSE: 0.08538061253248394 average PSNR: 23.688302462216964 average SSIM: 0.600524663003923 +[12:27:37.874] iteration 61100 [21.16 sec]: learning rate : 0.000063 loss : 0.298198 +[12:28:25.363] iteration 61200 [68.64 sec]: learning rate : 0.000063 loss : 0.387441 +[12:29:12.831] iteration 61300 [116.11 sec]: learning rate : 0.000063 loss : 0.263751 +[12:30:00.472] iteration 61400 [163.78 sec]: learning rate : 0.000063 loss : 0.376646 +[12:30:48.263] iteration 61500 [211.54 sec]: learning rate : 0.000063 loss : 0.298285 +[12:31:35.926] iteration 61600 [259.21 sec]: learning rate : 0.000063 loss : 0.315746 +[12:31:51.146] Epoch 106 Evaluation: +[12:36:11.473] average MSE: 0.06550098987829504 average PSNR: 24.85752239894303 average SSIM: 0.6282042329077786 +[12:36:43.987] iteration 61700 [32.49 sec]: learning rate : 0.000063 loss : 0.300381 +[12:37:31.668] iteration 61800 [80.17 sec]: learning rate : 0.000063 loss : 0.318082 +[12:38:19.884] iteration 61900 [128.39 sec]: learning rate : 0.000063 loss : 0.275542 +[12:39:07.480] iteration 62000 [175.98 sec]: learning rate : 0.000063 loss : 0.290977 +[12:39:55.168] iteration 62100 [223.67 sec]: learning rate : 0.000063 loss : 0.308949 +[12:40:42.853] iteration 62200 [271.36 sec]: learning rate : 0.000063 loss : 0.287788 +[12:40:46.661] Epoch 107 Evaluation: +[12:45:18.407] average MSE: 0.09461158428237136 average PSNR: 23.242758982328127 average SSIM: 0.599422910950434 +[12:46:02.174] iteration 62300 [43.74 sec]: learning rate : 0.000063 loss : 0.328783 +[12:46:49.728] iteration 62400 [91.30 sec]: learning rate : 0.000063 loss : 0.382801 +[12:47:37.573] iteration 62500 [139.14 sec]: learning rate : 0.000063 loss : 0.392379 +[12:48:25.643] iteration 62600 [187.21 sec]: learning rate : 0.000063 loss : 0.360830 +[12:49:13.207] iteration 62700 [234.78 sec]: learning rate : 0.000063 loss : 0.287374 +[12:49:52.966] Epoch 108 Evaluation: +[12:54:06.143] average MSE: 0.08638830163131736 average PSNR: 23.637074845148387 average SSIM: 0.5980294100260882 +[12:54:13.892] iteration 62800 [7.73 sec]: learning rate : 0.000063 loss : 0.312808 +[12:55:02.246] iteration 62900 [56.08 sec]: learning rate : 0.000063 loss : 0.333171 +[12:55:50.163] iteration 63000 [104.00 sec]: learning rate : 0.000063 loss : 0.384421 +[12:56:37.726] iteration 63100 [151.56 sec]: learning rate : 0.000063 loss : 0.288161 +[12:57:25.371] iteration 63200 [199.21 sec]: learning rate : 0.000063 loss : 0.305969 +[12:58:12.954] iteration 63300 [246.79 sec]: learning rate : 0.000063 loss : 0.361659 +[12:58:41.477] Epoch 109 Evaluation: +[13:03:08.186] average MSE: 0.06097227438428715 average PSNR: 25.176292403133456 average SSIM: 0.6412027312802269 +[13:03:27.567] iteration 63400 [19.36 sec]: learning rate : 0.000063 loss : 0.334991 +[13:04:15.190] iteration 63500 [66.98 
sec]: learning rate : 0.000063 loss : 0.352802 +[13:05:02.787] iteration 63600 [114.58 sec]: learning rate : 0.000063 loss : 0.329123 +[13:05:50.328] iteration 63700 [162.12 sec]: learning rate : 0.000063 loss : 0.392176 +[13:06:38.355] iteration 63800 [210.15 sec]: learning rate : 0.000063 loss : 0.318835 +[13:07:25.907] iteration 63900 [257.70 sec]: learning rate : 0.000063 loss : 0.280306 +[13:07:43.142] Epoch 110 Evaluation: +[13:12:09.464] average MSE: 0.063403886268526 average PSNR: 25.00088150382995 average SSIM: 0.6372116269539215 +[13:12:40.406] iteration 64000 [30.92 sec]: learning rate : 0.000063 loss : 0.367662 +[13:13:28.017] iteration 64100 [78.53 sec]: learning rate : 0.000063 loss : 0.358993 +[13:14:15.466] iteration 64200 [125.98 sec]: learning rate : 0.000063 loss : 0.285932 +[13:15:02.909] iteration 64300 [173.42 sec]: learning rate : 0.000063 loss : 0.361843 +[13:15:50.426] iteration 64400 [220.94 sec]: learning rate : 0.000063 loss : 0.273323 +[13:16:37.928] iteration 64500 [268.44 sec]: learning rate : 0.000063 loss : 0.240474 +[13:16:43.626] Epoch 111 Evaluation: +[13:21:03.259] average MSE: 0.0705188118860192 average PSNR: 24.530498752608956 average SSIM: 0.6191645533852789 +[13:21:45.229] iteration 64600 [41.95 sec]: learning rate : 0.000063 loss : 0.372090 +[13:22:32.651] iteration 64700 [89.37 sec]: learning rate : 0.000063 loss : 0.279000 +[13:23:20.000] iteration 64800 [136.72 sec]: learning rate : 0.000063 loss : 0.318477 +[13:24:07.416] iteration 64900 [184.13 sec]: learning rate : 0.000063 loss : 0.397427 +[13:24:55.160] iteration 65000 [231.88 sec]: learning rate : 0.000063 loss : 0.350716 +[13:25:36.814] Epoch 112 Evaluation: +[13:29:50.007] average MSE: 0.07131383603017358 average PSNR: 24.48185042579511 average SSIM: 0.615844200806439 +[13:29:55.911] iteration 65100 [5.88 sec]: learning rate : 0.000063 loss : 0.340639 +[13:30:43.822] iteration 65200 [53.79 sec]: learning rate : 0.000063 loss : 0.270015 +[13:31:31.428] iteration 65300 [101.40 sec]: learning rate : 0.000063 loss : 0.331296 +[13:32:18.934] iteration 65400 [148.90 sec]: learning rate : 0.000063 loss : 0.323244 +[13:33:06.866] iteration 65500 [196.84 sec]: learning rate : 0.000063 loss : 0.392508 +[13:33:54.172] iteration 65600 [244.14 sec]: learning rate : 0.000063 loss : 0.235795 +[13:34:24.598] Epoch 113 Evaluation: +[13:38:40.715] average MSE: 0.07868526585510453 average PSNR: 24.04791313351575 average SSIM: 0.601947598039404 +[13:38:57.945] iteration 65700 [17.21 sec]: learning rate : 0.000063 loss : 0.368422 +[13:39:45.426] iteration 65800 [64.69 sec]: learning rate : 0.000063 loss : 0.330301 +[13:40:32.723] iteration 65900 [111.98 sec]: learning rate : 0.000063 loss : 0.254707 +[13:41:20.328] iteration 66000 [159.59 sec]: learning rate : 0.000063 loss : 0.305405 +[13:42:07.755] iteration 66100 [207.02 sec]: learning rate : 0.000063 loss : 0.258965 +[13:42:55.476] iteration 66200 [254.74 sec]: learning rate : 0.000063 loss : 0.235328 +[13:43:14.416] Epoch 114 Evaluation: +[13:47:42.154] average MSE: 0.07240423767144809 average PSNR: 24.416675533281037 average SSIM: 0.6109903458292912 +[13:48:11.170] iteration 66300 [28.99 sec]: learning rate : 0.000063 loss : 0.314609 +[13:48:59.021] iteration 66400 [76.84 sec]: learning rate : 0.000063 loss : 0.249105 +[13:49:46.403] iteration 66500 [124.22 sec]: learning rate : 0.000063 loss : 0.366597 +[13:50:33.906] iteration 66600 [171.73 sec]: learning rate : 0.000063 loss : 0.406478 +[13:51:21.277] iteration 66700 [219.10 sec]: learning rate : 
0.000063 loss : 0.382319 +[13:52:08.792] iteration 66800 [266.61 sec]: learning rate : 0.000063 loss : 0.358355 +[13:52:16.379] Epoch 115 Evaluation: +[13:56:37.856] average MSE: 0.07789457910949779 average PSNR: 24.09147537303014 average SSIM: 0.6074251039999986 +[13:57:18.109] iteration 66900 [40.23 sec]: learning rate : 0.000063 loss : 0.342043 +[13:58:05.588] iteration 67000 [87.71 sec]: learning rate : 0.000063 loss : 0.400135 +[13:58:53.169] iteration 67100 [135.29 sec]: learning rate : 0.000063 loss : 0.371707 +[13:59:40.650] iteration 67200 [182.77 sec]: learning rate : 0.000063 loss : 0.282010 +[14:00:28.027] iteration 67300 [230.15 sec]: learning rate : 0.000063 loss : 0.345061 +[14:01:12.157] Epoch 116 Evaluation: +[14:05:31.121] average MSE: 0.07234938581790466 average PSNR: 24.419619991090673 average SSIM: 0.6132982224821752 +[14:05:35.209] iteration 67400 [4.07 sec]: learning rate : 0.000063 loss : 0.357154 +[14:06:22.762] iteration 67500 [51.62 sec]: learning rate : 0.000063 loss : 0.292083 +[14:07:10.487] iteration 67600 [99.34 sec]: learning rate : 0.000063 loss : 0.335681 +[14:07:58.000] iteration 67700 [146.86 sec]: learning rate : 0.000063 loss : 0.305071 +[14:08:45.382] iteration 67800 [194.24 sec]: learning rate : 0.000063 loss : 0.325979 +[14:09:32.793] iteration 67900 [241.65 sec]: learning rate : 0.000063 loss : 0.322786 +[14:10:05.070] Epoch 117 Evaluation: +[14:14:18.828] average MSE: 0.08097623330406578 average PSNR: 23.920414755715772 average SSIM: 0.6035869739428585 +[14:14:34.156] iteration 68000 [15.31 sec]: learning rate : 0.000063 loss : 0.419651 +[14:15:21.559] iteration 68100 [62.71 sec]: learning rate : 0.000063 loss : 0.366931 +[14:16:09.145] iteration 68200 [110.29 sec]: learning rate : 0.000063 loss : 0.390721 +[14:16:56.656] iteration 68300 [157.81 sec]: learning rate : 0.000063 loss : 0.262659 +[14:17:44.130] iteration 68400 [205.28 sec]: learning rate : 0.000063 loss : 0.351248 +[14:18:31.983] iteration 68500 [253.13 sec]: learning rate : 0.000063 loss : 0.368340 +[14:18:52.879] Epoch 118 Evaluation: +[14:23:09.713] average MSE: 0.07821691023822408 average PSNR: 24.073235661702398 average SSIM: 0.6045809403369912 +[14:23:36.643] iteration 68600 [26.91 sec]: learning rate : 0.000063 loss : 0.285359 +[14:24:24.106] iteration 68700 [74.37 sec]: learning rate : 0.000063 loss : 0.269550 +[14:25:11.665] iteration 68800 [121.93 sec]: learning rate : 0.000063 loss : 0.306867 +[14:25:59.402] iteration 68900 [169.67 sec]: learning rate : 0.000063 loss : 0.311184 +[14:26:46.938] iteration 69000 [217.20 sec]: learning rate : 0.000063 loss : 0.267680 +[14:27:34.476] iteration 69100 [264.74 sec]: learning rate : 0.000063 loss : 0.357137 +[14:27:43.971] Epoch 119 Evaluation: +[14:32:09.305] average MSE: 0.07981414730281926 average PSNR: 23.984653320570207 average SSIM: 0.6025203504205735 +[14:32:47.308] iteration 69200 [37.98 sec]: learning rate : 0.000063 loss : 0.360729 +[14:33:35.034] iteration 69300 [85.70 sec]: learning rate : 0.000063 loss : 0.220874 +[14:34:22.429] iteration 69400 [133.10 sec]: learning rate : 0.000063 loss : 0.290021 +[14:35:09.752] iteration 69500 [180.42 sec]: learning rate : 0.000063 loss : 0.292177 +[14:35:57.195] iteration 69600 [227.87 sec]: learning rate : 0.000063 loss : 0.302724 +[14:36:42.689] Epoch 120 Evaluation: +[14:41:02.673] average MSE: 0.07346380295164939 average PSNR: 24.34988865353173 average SSIM: 0.6129778809022024 +[14:41:04.746] iteration 69700 [2.05 sec]: learning rate : 0.000063 loss : 0.402277 +[14:41:52.206] 
iteration 69800 [49.51 sec]: learning rate : 0.000063 loss : 0.299177 +[14:42:39.612] iteration 69900 [96.92 sec]: learning rate : 0.000063 loss : 0.294124 +[14:43:26.897] iteration 70000 [144.20 sec]: learning rate : 0.000063 loss : 0.322013 +[14:44:14.291] iteration 70100 [191.60 sec]: learning rate : 0.000063 loss : 0.339213 +[14:45:01.686] iteration 70200 [238.99 sec]: learning rate : 0.000063 loss : 0.327270 +[14:45:35.749] Epoch 121 Evaluation: +[14:49:54.534] average MSE: 0.07884781502100754 average PSNR: 24.038593092777706 average SSIM: 0.6042331790500424 +[14:50:07.993] iteration 70300 [13.44 sec]: learning rate : 0.000063 loss : 0.339251 +[14:50:55.495] iteration 70400 [60.94 sec]: learning rate : 0.000063 loss : 0.348172 +[14:51:42.959] iteration 70500 [108.40 sec]: learning rate : 0.000063 loss : 0.341529 +[14:52:30.391] iteration 70600 [155.83 sec]: learning rate : 0.000063 loss : 0.298157 +[14:53:17.775] iteration 70700 [203.22 sec]: learning rate : 0.000063 loss : 0.320096 +[14:54:05.064] iteration 70800 [250.51 sec]: learning rate : 0.000063 loss : 0.316725 +[14:54:27.855] Epoch 122 Evaluation: +[14:58:54.184] average MSE: 0.08557902870755822 average PSNR: 23.678863201937446 average SSIM: 0.5990836094514294 +[14:59:19.061] iteration 70900 [24.85 sec]: learning rate : 0.000063 loss : 0.286358 +[15:00:06.643] iteration 71000 [72.44 sec]: learning rate : 0.000063 loss : 0.323623 +[15:00:54.082] iteration 71100 [119.87 sec]: learning rate : 0.000063 loss : 0.323837 +[15:01:41.622] iteration 71200 [167.41 sec]: learning rate : 0.000063 loss : 0.312440 +[15:02:29.187] iteration 71300 [214.98 sec]: learning rate : 0.000063 loss : 0.276204 +[15:03:16.874] iteration 71400 [262.67 sec]: learning rate : 0.000063 loss : 0.301004 +[15:03:28.265] Epoch 123 Evaluation: +[15:07:51.409] average MSE: 0.08026375746864367 average PSNR: 23.95983670446762 average SSIM: 0.603191123887964 +[15:08:27.697] iteration 71500 [36.26 sec]: learning rate : 0.000063 loss : 0.289495 +[15:09:15.060] iteration 71600 [83.63 sec]: learning rate : 0.000063 loss : 0.297440 +[15:10:02.597] iteration 71700 [131.17 sec]: learning rate : 0.000063 loss : 0.307761 +[15:10:50.019] iteration 71800 [178.59 sec]: learning rate : 0.000063 loss : 0.272036 +[15:11:37.387] iteration 71900 [225.96 sec]: learning rate : 0.000063 loss : 0.310687 +[15:12:24.843] iteration 72000 [273.41 sec]: learning rate : 0.000063 loss : 0.354956 +[15:12:24.878] Epoch 124 Evaluation: +[15:16:36.300] average MSE: 0.07670613983832025 average PSNR: 24.159627412218025 average SSIM: 0.6091583572935477 +[15:17:23.958] iteration 72100 [47.63 sec]: learning rate : 0.000063 loss : 0.393500 +[15:18:11.475] iteration 72200 [95.15 sec]: learning rate : 0.000063 loss : 0.378073 +[15:18:58.965] iteration 72300 [142.64 sec]: learning rate : 0.000063 loss : 0.206984 +[15:19:46.449] iteration 72400 [190.13 sec]: learning rate : 0.000063 loss : 0.264092 +[15:20:33.829] iteration 72500 [237.51 sec]: learning rate : 0.000063 loss : 0.331310 +[15:21:09.971] Epoch 125 Evaluation: +[15:25:24.982] average MSE: 0.0820996659073099 average PSNR: 23.85987876539848 average SSIM: 0.6028911521752258 +[15:25:36.566] iteration 72600 [11.56 sec]: learning rate : 0.000063 loss : 0.267803 +[15:26:23.871] iteration 72700 [58.87 sec]: learning rate : 0.000063 loss : 0.327029 +[15:27:11.336] iteration 72800 [106.33 sec]: learning rate : 0.000063 loss : 0.423284 +[15:27:58.734] iteration 72900 [153.73 sec]: learning rate : 0.000063 loss : 0.415928 +[15:28:46.025] iteration 73000 
[201.02 sec]: learning rate : 0.000063 loss : 0.348970 +[15:29:33.411] iteration 73100 [248.40 sec]: learning rate : 0.000063 loss : 0.239426 +[15:29:58.004] Epoch 126 Evaluation: +[15:34:12.442] average MSE: 0.08080112352931079 average PSNR: 23.92910214272987 average SSIM: 0.604088917829828 +[15:34:35.496] iteration 73200 [23.03 sec]: learning rate : 0.000063 loss : 0.349521 +[15:35:22.804] iteration 73300 [70.34 sec]: learning rate : 0.000063 loss : 0.354506 +[15:36:10.198] iteration 73400 [117.74 sec]: learning rate : 0.000063 loss : 0.342304 +[15:36:57.645] iteration 73500 [165.18 sec]: learning rate : 0.000063 loss : 0.311317 +[15:37:44.936] iteration 73600 [212.47 sec]: learning rate : 0.000063 loss : 0.353120 +[15:38:32.315] iteration 73700 [259.85 sec]: learning rate : 0.000063 loss : 0.325014 +[15:38:45.559] Epoch 127 Evaluation: +[15:43:02.546] average MSE: 0.09150413690488529 average PSNR: 23.387374153181817 average SSIM: 0.5979804399779765 +[15:43:36.781] iteration 73800 [34.21 sec]: learning rate : 0.000063 loss : 0.363872 +[15:44:24.230] iteration 73900 [81.66 sec]: learning rate : 0.000063 loss : 0.274586 +[15:45:11.574] iteration 74000 [129.00 sec]: learning rate : 0.000063 loss : 0.406479 +[15:45:58.856] iteration 74100 [176.29 sec]: learning rate : 0.000063 loss : 0.294440 +[15:46:46.209] iteration 74200 [223.64 sec]: learning rate : 0.000063 loss : 0.259235 +[15:47:33.955] iteration 74300 [271.39 sec]: learning rate : 0.000063 loss : 0.329089 +[15:47:35.854] Epoch 128 Evaluation: +[15:51:48.329] average MSE: 0.09291837357069087 average PSNR: 23.320794337110502 average SSIM: 0.6018180812932591 +[15:52:34.065] iteration 74400 [45.71 sec]: learning rate : 0.000063 loss : 0.309032 +[15:53:21.614] iteration 74500 [93.26 sec]: learning rate : 0.000063 loss : 0.345525 +[15:54:09.114] iteration 74600 [140.76 sec]: learning rate : 0.000063 loss : 0.291843 +[15:54:56.736] iteration 74700 [188.38 sec]: learning rate : 0.000063 loss : 0.300518 +[15:55:44.262] iteration 74800 [235.91 sec]: learning rate : 0.000063 loss : 0.352047 +[15:56:22.251] Epoch 129 Evaluation: +[16:00:45.868] average MSE: 0.0972296322565003 average PSNR: 23.12599343792586 average SSIM: 0.601159016944898 +[16:00:55.695] iteration 74900 [9.80 sec]: learning rate : 0.000063 loss : 0.347131 +[16:01:43.120] iteration 75000 [57.23 sec]: learning rate : 0.000063 loss : 0.331793 +[16:02:30.639] iteration 75100 [104.75 sec]: learning rate : 0.000063 loss : 0.324461 +[16:03:17.961] iteration 75200 [152.07 sec]: learning rate : 0.000063 loss : 0.357853 +[16:04:05.372] iteration 75300 [199.48 sec]: learning rate : 0.000063 loss : 0.317543 +[16:04:52.853] iteration 75400 [246.96 sec]: learning rate : 0.000063 loss : 0.293094 +[16:05:19.378] Epoch 130 Evaluation: +[16:09:33.566] average MSE: 0.08629114553061873 average PSNR: 23.64245529537021 average SSIM: 0.5998979159737563 +[16:09:54.841] iteration 75500 [21.25 sec]: learning rate : 0.000063 loss : 0.335099 +[16:10:42.277] iteration 75600 [68.69 sec]: learning rate : 0.000063 loss : 0.366128 +[16:11:29.650] iteration 75700 [116.06 sec]: learning rate : 0.000063 loss : 0.331296 +[16:12:17.088] iteration 75800 [163.50 sec]: learning rate : 0.000063 loss : 0.396886 +[16:13:04.488] iteration 75900 [210.90 sec]: learning rate : 0.000063 loss : 0.364344 +[16:13:51.933] iteration 76000 [258.34 sec]: learning rate : 0.000063 loss : 0.258717 +[16:14:07.158] Epoch 131 Evaluation: +[16:18:20.451] average MSE: 0.08585385701718057 average PSNR: 23.664601296639933 average SSIM: 
0.5952356196005127 +[16:18:52.831] iteration 76100 [32.36 sec]: learning rate : 0.000063 loss : 0.293793 +[16:19:40.284] iteration 76200 [79.81 sec]: learning rate : 0.000063 loss : 0.324021 +[16:20:27.646] iteration 76300 [127.17 sec]: learning rate : 0.000063 loss : 0.300106 +[16:21:15.243] iteration 76400 [174.77 sec]: learning rate : 0.000063 loss : 0.271505 +[16:22:02.816] iteration 76500 [222.34 sec]: learning rate : 0.000063 loss : 0.381749 +[16:22:50.295] iteration 76600 [269.82 sec]: learning rate : 0.000063 loss : 0.288792 +[16:22:54.097] Epoch 132 Evaluation: +[16:27:19.831] average MSE: 0.08629227590390162 average PSNR: 23.642810211975426 average SSIM: 0.5968798187753083 +[16:28:03.899] iteration 76700 [44.05 sec]: learning rate : 0.000063 loss : 0.334308 +[16:28:51.493] iteration 76800 [91.64 sec]: learning rate : 0.000063 loss : 0.398266 +[16:29:38.957] iteration 76900 [139.10 sec]: learning rate : 0.000063 loss : 0.425125 +[16:30:26.513] iteration 77000 [186.66 sec]: learning rate : 0.000063 loss : 0.387076 +[16:31:14.058] iteration 77100 [234.20 sec]: learning rate : 0.000063 loss : 0.295376 +[16:31:53.925] Epoch 133 Evaluation: +[16:36:19.657] average MSE: 0.08805249555572243 average PSNR: 23.553681338078796 average SSIM: 0.5976263421591493 +[16:36:27.415] iteration 77200 [7.74 sec]: learning rate : 0.000063 loss : 0.320688 +[16:37:14.896] iteration 77300 [55.22 sec]: learning rate : 0.000063 loss : 0.381141 +[16:38:02.173] iteration 77400 [102.49 sec]: learning rate : 0.000063 loss : 0.356608 +[16:38:49.564] iteration 77500 [149.88 sec]: learning rate : 0.000063 loss : 0.319656 +[16:39:37.125] iteration 77600 [197.44 sec]: learning rate : 0.000063 loss : 0.264194 +[16:40:24.446] iteration 77700 [244.76 sec]: learning rate : 0.000063 loss : 0.358967 +[16:40:52.958] Epoch 134 Evaluation: +[16:45:07.669] average MSE: 0.08847241530716442 average PSNR: 23.534087554626453 average SSIM: 0.5950681997956369 +[16:45:26.826] iteration 77800 [19.13 sec]: learning rate : 0.000063 loss : 0.323887 +[16:46:14.317] iteration 77900 [66.62 sec]: learning rate : 0.000063 loss : 0.352040 +[16:47:02.028] iteration 78000 [114.34 sec]: learning rate : 0.000063 loss : 0.339467 +[16:47:49.462] iteration 78100 [161.77 sec]: learning rate : 0.000063 loss : 0.363224 +[16:48:36.989] iteration 78200 [209.30 sec]: learning rate : 0.000063 loss : 0.318948 +[16:49:24.268] iteration 78300 [256.58 sec]: learning rate : 0.000063 loss : 0.291825 +[16:49:41.296] Epoch 135 Evaluation: +[16:53:53.538] average MSE: 0.08571842579138776 average PSNR: 23.672142994705734 average SSIM: 0.5958953363700239 +[16:54:24.457] iteration 78400 [30.90 sec]: learning rate : 0.000063 loss : 0.349442 +[16:55:11.944] iteration 78500 [78.38 sec]: learning rate : 0.000063 loss : 0.330072 +[16:55:59.542] iteration 78600 [125.98 sec]: learning rate : 0.000063 loss : 0.322710 +[16:56:47.160] iteration 78700 [173.60 sec]: learning rate : 0.000063 loss : 0.415643 +[16:57:34.436] iteration 78800 [220.87 sec]: learning rate : 0.000063 loss : 0.266997 +[16:58:21.881] iteration 78900 [268.32 sec]: learning rate : 0.000063 loss : 0.252067 +[16:58:27.559] Epoch 136 Evaluation: +[17:02:40.605] average MSE: 0.0904990772224593 average PSNR: 23.4355606923785 average SSIM: 0.5939534372267462 +[17:03:22.597] iteration 79000 [41.97 sec]: learning rate : 0.000063 loss : 0.365324 +[17:04:09.897] iteration 79100 [89.27 sec]: learning rate : 0.000063 loss : 0.303854 +[17:04:57.290] iteration 79200 [136.66 sec]: learning rate : 0.000063 loss : 0.371519 
+[17:05:44.597] iteration 79300 [183.97 sec]: learning rate : 0.000063 loss : 0.401476 +[17:06:32.046] iteration 79400 [231.42 sec]: learning rate : 0.000063 loss : 0.298504 +[17:07:13.900] Epoch 137 Evaluation: +[17:11:28.075] average MSE: 0.08703444755412566 average PSNR: 23.604942614053193 average SSIM: 0.5958105387007061 +[17:11:33.930] iteration 79500 [5.83 sec]: learning rate : 0.000063 loss : 0.320568 +[17:12:21.255] iteration 79600 [53.16 sec]: learning rate : 0.000063 loss : 0.263406 +[17:13:08.710] iteration 79700 [100.61 sec]: learning rate : 0.000063 loss : 0.324105 +[17:13:56.121] iteration 79800 [148.02 sec]: learning rate : 0.000063 loss : 0.312223 +[17:14:43.400] iteration 79900 [195.30 sec]: learning rate : 0.000063 loss : 0.336048 +[17:15:30.783] iteration 80000 [242.68 sec]: learning rate : 0.000016 loss : 0.275559 +[17:15:30.940] save model to model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/iter_80000.pth +[17:16:01.220] Epoch 138 Evaluation: +[17:20:13.749] average MSE: 0.08516419539928599 average PSNR: 23.7003475488251 average SSIM: 0.5958624254433957 +[17:20:31.154] iteration 80100 [17.38 sec]: learning rate : 0.000031 loss : 0.337406 +[17:21:18.451] iteration 80200 [64.67 sec]: learning rate : 0.000031 loss : 0.341736 +[17:22:05.813] iteration 80300 [112.04 sec]: learning rate : 0.000031 loss : 0.270572 +[17:22:53.206] iteration 80400 [159.43 sec]: learning rate : 0.000031 loss : 0.322405 +[17:23:40.488] iteration 80500 [206.71 sec]: learning rate : 0.000031 loss : 0.269837 +[17:24:28.164] iteration 80600 [254.39 sec]: learning rate : 0.000031 loss : 0.237562 +[17:24:47.090] Epoch 139 Evaluation: +[17:29:00.544] average MSE: 0.08867716215637722 average PSNR: 23.52395588903089 average SSIM: 0.5961596395039317 +[17:29:29.096] iteration 80700 [28.53 sec]: learning rate : 0.000031 loss : 0.307078 +[17:30:16.571] iteration 80800 [76.00 sec]: learning rate : 0.000031 loss : 0.302926 +[17:31:04.117] iteration 80900 [123.55 sec]: learning rate : 0.000031 loss : 0.377611 +[17:31:51.771] iteration 81000 [171.20 sec]: learning rate : 0.000031 loss : 0.391945 +[17:32:39.292] iteration 81100 [218.73 sec]: learning rate : 0.000031 loss : 0.328367 +[17:33:26.821] iteration 81200 [266.26 sec]: learning rate : 0.000031 loss : 0.315590 +[17:33:34.412] Epoch 140 Evaluation: +[17:37:46.459] average MSE: 0.09076253437490876 average PSNR: 23.42259360430759 average SSIM: 0.5951054092806579 +[17:38:26.542] iteration 81300 [40.06 sec]: learning rate : 0.000031 loss : 0.343847 +[17:39:14.178] iteration 81400 [87.69 sec]: learning rate : 0.000031 loss : 0.694497 +[17:40:01.694] iteration 81500 [135.21 sec]: learning rate : 0.000031 loss : 0.318844 +[17:40:49.133] iteration 81600 [182.65 sec]: learning rate : 0.000031 loss : 0.311067 +[17:41:36.646] iteration 81700 [230.16 sec]: learning rate : 0.000031 loss : 0.302936 +[17:42:20.406] Epoch 141 Evaluation: +[17:46:33.795] average MSE: 0.09302522583245199 average PSNR: 23.3162747738277 average SSIM: 0.5970465544731111 +[17:46:37.905] iteration 81800 [4.09 sec]: learning rate : 0.000031 loss : 0.402789 +[17:47:25.202] iteration 81900 [51.38 sec]: learning rate : 0.000031 loss : 0.295364 +[17:48:12.564] iteration 82000 [98.76 sec]: learning rate : 0.000031 loss : 0.349493 +[17:48:59.892] iteration 82100 [146.07 sec]: learning rate : 0.000031 loss : 0.325944 +[17:49:47.291] iteration 82200 [193.47 sec]: learning rate : 0.000031 loss : 0.302482 +[17:50:34.692] iteration 82300 [240.87 sec]: learning rate : 0.000031 loss : 0.325755 +[17:51:06.889] 
Epoch 142 Evaluation: +[17:55:20.907] average MSE: 0.09184846745035478 average PSNR: 23.370640569880276 average SSIM: 0.5945948514385427 +[17:55:36.220] iteration 82400 [15.29 sec]: learning rate : 0.000031 loss : 0.359836 +[17:56:23.682] iteration 82500 [62.75 sec]: learning rate : 0.000031 loss : 0.367967 +[17:57:11.042] iteration 82600 [110.11 sec]: learning rate : 0.000031 loss : 0.375608 +[17:57:58.360] iteration 82700 [157.43 sec]: learning rate : 0.000031 loss : 0.284824 +[17:58:45.765] iteration 82800 [204.83 sec]: learning rate : 0.000031 loss : 0.329479 +[17:59:33.135] iteration 82900 [252.20 sec]: learning rate : 0.000031 loss : 0.325083 +[17:59:53.938] Epoch 143 Evaluation: +[18:04:07.084] average MSE: 0.08843814000431484 average PSNR: 23.536564110766786 average SSIM: 0.5928691068833983 +[18:04:33.799] iteration 83000 [26.69 sec]: learning rate : 0.000031 loss : 0.282465 +[18:05:21.394] iteration 83100 [74.29 sec]: learning rate : 0.000031 loss : 0.255827 +[18:06:08.849] iteration 83200 [121.74 sec]: learning rate : 0.000031 loss : 0.319033 +[18:06:56.315] iteration 83300 [169.21 sec]: learning rate : 0.000031 loss : 0.330988 +[18:07:43.760] iteration 83400 [216.65 sec]: learning rate : 0.000031 loss : 0.253027 +[18:08:31.046] iteration 83500 [263.94 sec]: learning rate : 0.000031 loss : 0.320414 +[18:08:40.595] Epoch 144 Evaluation: +[18:12:52.646] average MSE: 0.08746568066398734 average PSNR: 23.584170821657327 average SSIM: 0.5956751849388835 +[18:13:30.676] iteration 83600 [38.01 sec]: learning rate : 0.000031 loss : 0.419870 +[18:14:18.115] iteration 83700 [85.44 sec]: learning rate : 0.000031 loss : 0.246951 +[18:15:05.401] iteration 83800 [132.73 sec]: learning rate : 0.000031 loss : 0.648370 +[18:15:52.797] iteration 83900 [180.13 sec]: learning rate : 0.000031 loss : 0.314090 +[18:16:40.164] iteration 84000 [227.49 sec]: learning rate : 0.000031 loss : 0.309298 +[18:17:25.844] Epoch 145 Evaluation: +[18:21:39.673] average MSE: 0.0842376374852002 average PSNR: 23.748624955098425 average SSIM: 0.5978044734049258 +[18:21:41.755] iteration 84100 [2.06 sec]: learning rate : 0.000031 loss : 0.372882 +[18:22:29.627] iteration 84200 [49.93 sec]: learning rate : 0.000031 loss : 0.270905 +[18:23:17.150] iteration 84300 [97.45 sec]: learning rate : 0.000031 loss : 0.285993 +[18:24:04.586] iteration 84400 [144.89 sec]: learning rate : 0.000031 loss : 0.336487 +[18:24:52.109] iteration 84500 [192.41 sec]: learning rate : 0.000031 loss : 0.326378 +[18:25:39.518] iteration 84600 [239.82 sec]: learning rate : 0.000031 loss : 0.382500 +[18:26:13.779] Epoch 146 Evaluation: +[18:30:40.252] average MSE: 0.0858863808870997 average PSNR: 23.66399508150318 average SSIM: 0.5965468942233575 +[18:30:53.674] iteration 84700 [13.40 sec]: learning rate : 0.000031 loss : 0.347554 +[18:31:41.095] iteration 84800 [60.82 sec]: learning rate : 0.000031 loss : 0.331283 +[18:32:28.322] iteration 84900 [108.05 sec]: learning rate : 0.000031 loss : 0.261350 +[18:33:15.652] iteration 85000 [155.38 sec]: learning rate : 0.000031 loss : 0.307648 +[18:34:02.996] iteration 85100 [202.72 sec]: learning rate : 0.000031 loss : 0.303619 +[18:34:50.243] iteration 85200 [249.97 sec]: learning rate : 0.000031 loss : 0.295881 +[18:35:13.031] Epoch 147 Evaluation: +[18:39:24.629] average MSE: 0.09121857833331483 average PSNR: 23.40136471387551 average SSIM: 0.5966884796348376 +[18:39:49.375] iteration 85300 [24.72 sec]: learning rate : 0.000031 loss : 0.280600 +[18:40:36.776] iteration 85400 [72.12 sec]: learning rate 
: 0.000031 loss : 0.298839 +[18:41:24.009] iteration 85500 [119.36 sec]: learning rate : 0.000031 loss : 0.349275 +[18:42:11.339] iteration 85600 [166.69 sec]: learning rate : 0.000031 loss : 0.310753 +[18:42:58.589] iteration 85700 [213.94 sec]: learning rate : 0.000031 loss : 0.319926 +[18:43:45.925] iteration 85800 [261.27 sec]: learning rate : 0.000031 loss : 0.292075 +[18:43:57.267] Epoch 148 Evaluation: +[18:48:10.209] average MSE: 0.08558340061382383 average PSNR: 23.679416694956544 average SSIM: 0.5998307483921856 +[18:48:46.471] iteration 85900 [36.24 sec]: learning rate : 0.000031 loss : 0.292930 +[18:49:33.732] iteration 86000 [83.50 sec]: learning rate : 0.000031 loss : 0.354971 +[18:50:21.075] iteration 86100 [130.84 sec]: learning rate : 0.000031 loss : 0.337205 +[18:51:08.437] iteration 86200 [178.20 sec]: learning rate : 0.000031 loss : 0.259214 +[18:51:55.663] iteration 86300 [225.43 sec]: learning rate : 0.000031 loss : 0.307457 +[18:52:42.993] iteration 86400 [272.76 sec]: learning rate : 0.000031 loss : 0.326890 +[18:52:43.028] Epoch 149 Evaluation: +[18:56:55.793] average MSE: 0.08513376010066068 average PSNR: 23.70190379156585 average SSIM: 0.5993685570799929 +[18:57:43.369] iteration 86500 [47.55 sec]: learning rate : 0.000031 loss : 0.319128 +[18:58:30.617] iteration 86600 [94.80 sec]: learning rate : 0.000031 loss : 0.329022 +[18:59:17.968] iteration 86700 [142.15 sec]: learning rate : 0.000031 loss : 0.231152 +[19:00:05.333] iteration 86800 [189.52 sec]: learning rate : 0.000031 loss : 0.274557 +[19:00:52.571] iteration 86900 [236.75 sec]: learning rate : 0.000031 loss : 0.296073 +[19:01:28.592] Epoch 150 Evaluation: +[19:05:40.596] average MSE: 0.0848180552422272 average PSNR: 23.718303639955874 average SSIM: 0.5997320815179692 +[19:05:52.121] iteration 87000 [11.50 sec]: learning rate : 0.000031 loss : 0.300009 +[19:06:39.383] iteration 87100 [58.76 sec]: learning rate : 0.000031 loss : 0.369990 +[19:07:26.786] iteration 87200 [106.17 sec]: learning rate : 0.000031 loss : 0.360018 +[19:08:14.140] iteration 87300 [153.52 sec]: learning rate : 0.000031 loss : 0.368951 +[19:09:01.398] iteration 87400 [200.78 sec]: learning rate : 0.000031 loss : 0.341117 +[19:09:48.744] iteration 87500 [248.12 sec]: learning rate : 0.000031 loss : 0.292174 +[19:10:13.324] Epoch 151 Evaluation: +[19:14:25.758] average MSE: 0.08691449179292052 average PSNR: 23.611334069757806 average SSIM: 0.5971547416935896 +[19:14:48.759] iteration 87600 [22.98 sec]: learning rate : 0.000031 loss : 0.308144 +[19:15:36.001] iteration 87700 [70.22 sec]: learning rate : 0.000031 loss : 0.351883 +[19:16:23.333] iteration 87800 [117.55 sec]: learning rate : 0.000031 loss : 0.338996 +[19:17:10.654] iteration 87900 [164.87 sec]: learning rate : 0.000031 loss : 0.351489 +[19:17:57.893] iteration 88000 [212.11 sec]: learning rate : 0.000031 loss : 0.351631 +[19:18:45.257] iteration 88100 [259.48 sec]: learning rate : 0.000031 loss : 0.353475 +[19:18:58.494] Epoch 152 Evaluation: +[19:23:09.106] average MSE: 0.09024174593308212 average PSNR: 23.44775870533755 average SSIM: 0.5974754793528649 +[19:23:43.311] iteration 88200 [34.18 sec]: learning rate : 0.000031 loss : 0.342029 +[19:24:30.702] iteration 88300 [81.57 sec]: learning rate : 0.000031 loss : 0.314511 +[19:25:18.027] iteration 88400 [128.90 sec]: learning rate : 0.000031 loss : 0.395314 +[19:26:05.272] iteration 88500 [176.14 sec]: learning rate : 0.000031 loss : 0.299957 +[19:26:52.609] iteration 88600 [223.48 sec]: learning rate : 0.000031 loss : 
0.282914 +[19:27:39.942] iteration 88700 [270.81 sec]: learning rate : 0.000031 loss : 0.339245 +[19:27:41.840] Epoch 153 Evaluation: +[19:31:54.820] average MSE: 0.08805294945390275 average PSNR: 23.554061980789612 average SSIM: 0.5963791888133919 +[19:32:40.501] iteration 88800 [45.66 sec]: learning rate : 0.000031 loss : 0.308953 +[19:33:27.962] iteration 88900 [93.12 sec]: learning rate : 0.000031 loss : 0.343385 +[19:34:15.290] iteration 89000 [140.45 sec]: learning rate : 0.000031 loss : 0.324587 +[19:35:02.525] iteration 89100 [187.68 sec]: learning rate : 0.000031 loss : 0.363150 +[19:35:49.819] iteration 89200 [234.98 sec]: learning rate : 0.000031 loss : 0.324637 +[19:36:27.624] Epoch 154 Evaluation: +[19:40:38.430] average MSE: 0.09815625598656295 average PSNR: 23.085318868069287 average SSIM: 0.5994340736102317 +[19:40:48.060] iteration 89300 [9.61 sec]: learning rate : 0.000031 loss : 0.333633 +[19:41:35.454] iteration 89400 [57.00 sec]: learning rate : 0.000031 loss : 0.380018 +[19:42:22.750] iteration 89500 [104.30 sec]: learning rate : 0.000031 loss : 0.302207 +[19:43:09.979] iteration 89600 [151.52 sec]: learning rate : 0.000031 loss : 0.316849 +[19:43:57.285] iteration 89700 [198.83 sec]: learning rate : 0.000031 loss : 0.293116 +[19:44:44.614] iteration 89800 [246.16 sec]: learning rate : 0.000031 loss : 0.322625 +[19:45:11.071] Epoch 155 Evaluation: +[19:49:22.247] average MSE: 0.08604374969858096 average PSNR: 23.655304904452233 average SSIM: 0.5952507589564227 +[19:49:43.212] iteration 89900 [20.94 sec]: learning rate : 0.000031 loss : 0.317754 +[19:50:30.602] iteration 90000 [68.33 sec]: learning rate : 0.000031 loss : 0.340005 +[19:51:17.904] iteration 90100 [115.63 sec]: learning rate : 0.000031 loss : 0.289783 +[19:52:05.130] iteration 90200 [162.86 sec]: learning rate : 0.000031 loss : 0.374804 +[19:52:52.435] iteration 90300 [210.16 sec]: learning rate : 0.000031 loss : 0.321266 +[19:53:39.666] iteration 90400 [257.40 sec]: learning rate : 0.000031 loss : 0.275065 +[19:53:54.883] Epoch 156 Evaluation: +[19:58:13.875] average MSE: 0.09057285516823739 average PSNR: 23.431282802107273 average SSIM: 0.5974755695147858 +[19:58:46.319] iteration 90500 [32.42 sec]: learning rate : 0.000031 loss : 0.308280 +[19:59:33.901] iteration 90600 [80.00 sec]: learning rate : 0.000031 loss : 0.319602 +[20:00:21.150] iteration 90700 [127.25 sec]: learning rate : 0.000031 loss : 0.298562 +[20:01:08.468] iteration 90800 [174.57 sec]: learning rate : 0.000031 loss : 0.261119 +[20:01:55.782] iteration 90900 [221.88 sec]: learning rate : 0.000031 loss : 0.362256 +[20:02:43.029] iteration 91000 [269.13 sec]: learning rate : 0.000031 loss : 0.378057 +[20:02:46.821] Epoch 157 Evaluation: +[20:06:57.103] average MSE: 0.08942070850596899 average PSNR: 23.48678014261033 average SSIM: 0.5945870804456798 +[20:07:40.861] iteration 91100 [43.73 sec]: learning rate : 0.000031 loss : 0.304946 +[20:08:28.259] iteration 91200 [91.13 sec]: learning rate : 0.000031 loss : 0.423159 +[20:09:15.637] iteration 91300 [138.51 sec]: learning rate : 0.000031 loss : 0.336694 +[20:10:02.957] iteration 91400 [185.83 sec]: learning rate : 0.000031 loss : 0.399332 +[20:10:50.170] iteration 91500 [233.04 sec]: learning rate : 0.000031 loss : 0.314073 +[20:11:29.974] Epoch 158 Evaluation: +[20:15:46.390] average MSE: 0.08709384109351541 average PSNR: 23.602298352386576 average SSIM: 0.5986586610460628 +[20:15:54.166] iteration 91600 [7.75 sec]: learning rate : 0.000031 loss : 0.339023 +[20:16:41.615] iteration 91700 
[55.20 sec]: learning rate : 0.000031 loss : 0.365109 +[20:17:28.844] iteration 91800 [102.43 sec]: learning rate : 0.000031 loss : 0.370055 +[20:18:16.179] iteration 91900 [149.76 sec]: learning rate : 0.000031 loss : 0.363242 +[20:19:03.521] iteration 92000 [197.11 sec]: learning rate : 0.000031 loss : 0.296721 +[20:19:50.919] iteration 92100 [244.50 sec]: learning rate : 0.000031 loss : 0.362998 +[20:20:19.381] Epoch 159 Evaluation: +[20:24:43.465] average MSE: 0.08801557563455623 average PSNR: 23.55594443296422 average SSIM: 0.5998956523052532 +[20:25:02.556] iteration 92200 [19.07 sec]: learning rate : 0.000031 loss : 0.296902 +[20:25:49.995] iteration 92300 [66.50 sec]: learning rate : 0.000031 loss : 0.333189 +[20:26:37.306] iteration 92400 [113.82 sec]: learning rate : 0.000031 loss : 0.335192 +[20:27:24.667] iteration 92500 [161.18 sec]: learning rate : 0.000031 loss : 0.345729 +[20:28:11.986] iteration 92600 [208.50 sec]: learning rate : 0.000031 loss : 0.313964 +[20:28:59.490] iteration 92700 [256.00 sec]: learning rate : 0.000031 loss : 0.248915 +[20:29:16.513] Epoch 160 Evaluation: +[20:33:28.521] average MSE: 0.08746549142651902 average PSNR: 23.58384937823133 average SSIM: 0.5982274580773609 +[20:33:59.094] iteration 92800 [30.55 sec]: learning rate : 0.000031 loss : 0.342554 +[20:34:46.362] iteration 92900 [77.82 sec]: learning rate : 0.000031 loss : 0.335847 +[20:35:33.707] iteration 93000 [125.16 sec]: learning rate : 0.000031 loss : 0.252287 +[20:36:21.050] iteration 93100 [172.50 sec]: learning rate : 0.000031 loss : 0.316572 +[20:37:08.292] iteration 93200 [219.75 sec]: learning rate : 0.000031 loss : 0.289652 +[20:37:55.638] iteration 93300 [267.09 sec]: learning rate : 0.000031 loss : 0.293152 +[20:38:01.309] Epoch 161 Evaluation: +[20:42:12.603] average MSE: 0.0908948907244661 average PSNR: 23.41646642264121 average SSIM: 0.5992631720259277 +[20:42:54.593] iteration 93400 [41.97 sec]: learning rate : 0.000031 loss : 0.349669 +[20:43:42.013] iteration 93500 [89.39 sec]: learning rate : 0.000031 loss : 0.292131 +[20:44:29.487] iteration 93600 [136.86 sec]: learning rate : 0.000031 loss : 0.345486 +[20:45:16.989] iteration 93700 [184.36 sec]: learning rate : 0.000031 loss : 0.428537 +[20:46:04.398] iteration 93800 [231.77 sec]: learning rate : 0.000031 loss : 0.347806 +[20:46:46.113] Epoch 162 Evaluation: +[20:50:59.605] average MSE: 0.09091512709898959 average PSNR: 23.415682536912836 average SSIM: 0.5981859358220509 +[20:51:05.489] iteration 93900 [5.86 sec]: learning rate : 0.000031 loss : 0.349282 +[20:51:53.094] iteration 94000 [53.46 sec]: learning rate : 0.000031 loss : 0.331163 +[20:52:40.522] iteration 94100 [100.89 sec]: learning rate : 0.000031 loss : 0.333570 +[20:53:28.047] iteration 94200 [148.42 sec]: learning rate : 0.000031 loss : 0.336541 +[20:54:15.456] iteration 94300 [195.83 sec]: learning rate : 0.000031 loss : 0.327157 +[20:55:03.008] iteration 94400 [243.38 sec]: learning rate : 0.000031 loss : 0.262900 +[20:55:33.368] Epoch 163 Evaluation: +[20:59:58.211] average MSE: 0.09526308671244117 average PSNR: 23.21373965281788 average SSIM: 0.598426523986129 +[21:00:15.642] iteration 94500 [17.41 sec]: learning rate : 0.000031 loss : 0.323912 +[21:01:02.984] iteration 94600 [64.75 sec]: learning rate : 0.000031 loss : 0.364428 +[21:01:50.389] iteration 94700 [112.15 sec]: learning rate : 0.000031 loss : 0.294990 +[21:02:37.752] iteration 94800 [159.52 sec]: learning rate : 0.000031 loss : 0.290515 +[21:03:25.037] iteration 94900 [206.80 sec]: learning 
rate : 0.000031 loss : 0.303892 +[21:04:14.485] iteration 95000 [256.27 sec]: learning rate : 0.000031 loss : 0.263780 +[21:04:34.199] Epoch 164 Evaluation: +[21:09:09.570] average MSE: 0.0947119171304278 average PSNR: 23.238821297933885 average SSIM: 0.5965737235079086 +[21:09:38.753] iteration 95100 [29.16 sec]: learning rate : 0.000031 loss : 0.328515 +[21:10:26.529] iteration 95200 [76.94 sec]: learning rate : 0.000031 loss : 0.278313 +[21:11:14.760] iteration 95300 [125.19 sec]: learning rate : 0.000031 loss : 0.418723 +[21:12:02.616] iteration 95400 [173.02 sec]: learning rate : 0.000031 loss : 0.342949 +[21:12:50.883] iteration 95500 [221.31 sec]: learning rate : 0.000031 loss : 0.403657 +[21:13:38.978] iteration 95600 [269.39 sec]: learning rate : 0.000031 loss : 0.328753 +[21:13:46.673] Epoch 165 Evaluation: +[21:18:18.907] average MSE: 0.09161040815561042 average PSNR: 23.38224556690292 average SSIM: 0.5952488797042502 +[21:18:59.150] iteration 95700 [40.22 sec]: learning rate : 0.000031 loss : 0.377561 +[21:19:47.152] iteration 95800 [88.23 sec]: learning rate : 0.000031 loss : 0.345810 +[21:20:34.994] iteration 95900 [136.06 sec]: learning rate : 0.000031 loss : 0.340078 +[21:21:22.749] iteration 96000 [183.82 sec]: learning rate : 0.000031 loss : 0.234590 +[21:22:10.679] iteration 96100 [231.75 sec]: learning rate : 0.000031 loss : 0.287376 +[21:22:54.711] Epoch 166 Evaluation: +[21:27:19.494] average MSE: 0.08815216925836042 average PSNR: 23.54997506344478 average SSIM: 0.5969141114828541 +[21:27:23.502] iteration 96200 [3.98 sec]: learning rate : 0.000031 loss : 0.417144 +[21:28:11.051] iteration 96300 [51.53 sec]: learning rate : 0.000031 loss : 0.294322 +[21:28:58.711] iteration 96400 [99.19 sec]: learning rate : 0.000031 loss : 0.324215 +[21:29:46.391] iteration 96500 [146.87 sec]: learning rate : 0.000031 loss : 0.312632 +[21:30:33.938] iteration 96600 [194.42 sec]: learning rate : 0.000031 loss : 0.345433 +[21:31:21.582] iteration 96700 [242.07 sec]: learning rate : 0.000031 loss : 0.313912 +[21:31:53.913] Epoch 167 Evaluation: +[21:36:09.339] average MSE: 0.1017631480633969 average PSNR: 22.92997800104665 average SSIM: 0.6057050979245052 +[21:36:24.704] iteration 96800 [15.34 sec]: learning rate : 0.000031 loss : 0.320290 +[21:37:12.387] iteration 96900 [63.02 sec]: learning rate : 0.000031 loss : 0.370471 +[21:37:59.920] iteration 97000 [110.56 sec]: learning rate : 0.000031 loss : 0.412411 +[21:38:47.473] iteration 97100 [158.11 sec]: learning rate : 0.000031 loss : 0.251570 +[21:39:35.235] iteration 97200 [205.87 sec]: learning rate : 0.000031 loss : 0.357823 +[21:40:23.047] iteration 97300 [253.69 sec]: learning rate : 0.000031 loss : 0.365984 +[21:40:43.942] Epoch 168 Evaluation: +[21:45:01.051] average MSE: 0.0954475306846593 average PSNR: 23.20522038343092 average SSIM: 0.6002159005805364 +[21:45:27.699] iteration 97400 [26.62 sec]: learning rate : 0.000031 loss : 0.288936 +[21:46:15.090] iteration 97500 [74.01 sec]: learning rate : 0.000031 loss : 0.250054 +[21:47:02.504] iteration 97600 [121.43 sec]: learning rate : 0.000031 loss : 0.348146 +[21:47:49.748] iteration 97700 [168.67 sec]: learning rate : 0.000031 loss : 0.292356 +[21:48:37.091] iteration 97800 [216.02 sec]: learning rate : 0.000031 loss : 0.229192 +[21:49:24.430] iteration 97900 [263.36 sec]: learning rate : 0.000031 loss : 0.422173 +[21:49:33.982] Epoch 169 Evaluation: +[21:53:47.891] average MSE: 0.09506812734438186 average PSNR: 23.22235003580248 average SSIM: 0.596511156580592 +[21:54:25.860] 
iteration 98000 [37.95 sec]: learning rate : 0.000031 loss : 0.425657 +[21:55:13.224] iteration 98100 [85.31 sec]: learning rate : 0.000031 loss : 0.226705 +[21:56:00.462] iteration 98200 [132.55 sec]: learning rate : 0.000031 loss : 0.277875 +[21:56:47.814] iteration 98300 [179.90 sec]: learning rate : 0.000031 loss : 0.308913 +[21:57:35.140] iteration 98400 [227.23 sec]: learning rate : 0.000031 loss : 0.260142 +[21:58:20.501] Epoch 170 Evaluation: +[22:02:32.723] average MSE: 0.08744926632784641 average PSNR: 23.5849711177071 average SSIM: 0.5954700952618779 +[22:02:34.804] iteration 98500 [2.06 sec]: learning rate : 0.000031 loss : 0.380053 +[22:03:22.228] iteration 98600 [49.48 sec]: learning rate : 0.000031 loss : 0.325020 +[22:04:09.590] iteration 98700 [96.84 sec]: learning rate : 0.000031 loss : 0.268120 +[22:04:56.874] iteration 98800 [144.13 sec]: learning rate : 0.000031 loss : 0.315509 +[22:05:44.218] iteration 98900 [191.47 sec]: learning rate : 0.000031 loss : 0.345702 +[22:06:31.444] iteration 99000 [238.70 sec]: learning rate : 0.000031 loss : 0.321509 +[22:07:05.572] Epoch 171 Evaluation: +===> Evaluate Metric <=== +Results +------------------------------------ +ColdDiffusion NMSE: 4.0699 ± 0.3824 +ColdDiffusion PSNR: 27.8420 ± 0.4646 +ColdDiffusion SSIM: 0.7499 ± 0.0109 +------------------------------------ +All NMSE: 4.0620 ± 0.6612 +All PSNR: 26.7950 ± 0.7506 +All SSIM: 0.7222 ± 0.0210 +------------------------------------ +[22:11:18.446] average MSE: 0.08633650612099658 average PSNR: 23.640632693181043 average SSIM: 0.5967397778743725 +[22:11:31.900] iteration 99100 [13.43 sec]: learning rate : 0.000031 loss : 0.379974 +[22:12:19.404] iteration 99200 [60.94 sec]: learning rate : 0.000031 loss : 0.330017 +[22:13:06.635] iteration 99300 [108.17 sec]: learning rate : 0.000031 loss : 0.263681 +[22:13:53.968] iteration 99400 [155.50 sec]: learning rate : 0.000031 loss : 0.519775 +[22:14:41.286] iteration 99500 [202.82 sec]: learning rate : 0.000031 loss : 0.372014 +[22:15:28.523] iteration 99600 [250.05 sec]: learning rate : 0.000031 loss : 0.291419 +[22:15:51.282] Epoch 172 Evaluation: +[22:20:03.171] average MSE: 0.09507555797765292 average PSNR: 23.221624544918896 average SSIM: 0.5957881855697643 +[22:20:28.006] iteration 99700 [24.81 sec]: learning rate : 0.000031 loss : 0.249940 +[22:21:15.738] iteration 99800 [72.54 sec]: learning rate : 0.000031 loss : 0.338559 +[22:22:03.145] iteration 99900 [119.95 sec]: learning rate : 0.000031 loss : 0.375334 +[22:22:50.643] iteration 100000 [167.45 sec]: learning rate : 0.000008 loss : 0.316004 +[22:22:50.838] save model to model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/iter_100000.pth +[22:22:51.325] Epoch 173 Evaluation: +[22:27:06.028] average MSE: 0.08803852952607479 average PSNR: 23.55578282961652 average SSIM: 0.59592781868209 +[22:27:06.335] save model to model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/iter_100000.pth diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/log/events.out.tfevents.1752550861.GCRSANDBOX133 b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/log/events.out.tfevents.1752550861.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..726c60981e533ef5611ae61d0b507e621c0ac3c8 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t30_new_kspace_time/log/events.out.tfevents.1752550861.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version
https://git-lfs.github.com/spec/v1 +oid sha256:f88f3490d1d4901ab5fdb1e51ebde1afa8e0f01b968e97853d03e1b56736dd34 +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/best_checkpoint.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..ab71bc9c4911b0624e2e5dc7b49418a8ab623191 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a6e21105c0e5e6925c14bae1a9deb95d178b2c1c94041c73cb6b7b02983f6b2 +size 56614874 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/log.txt b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..f0a8beda3708357efb67724b792140aa7fdc3ad0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/log.txt @@ -0,0 +1,1366 @@ +[20:31:55.781] Namespace(root_path='/home/v-qichen3/MRI_recon/data/m4raw', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='0', exp='FSMNet_m4raw_4x_lr5e-4', max_iterations=100000, batch_size=4, base_lr=0.0005, seed=1337, resume=None, relation_consistency='False', clip_grad='True', norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=5, image_size=240, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[20:38:06.890] iteration 100 [52.76 sec]: learning rate : 0.000500 loss : 0.508356 +[20:38:57.621] iteration 200 [103.50 sec]: learning rate : 0.000500 loss : 0.651396 +[20:39:48.431] iteration 300 [154.29 sec]: learning rate : 0.000500 loss : 0.450083 +[20:40:38.802] iteration 400 [204.66 sec]: learning rate : 0.000500 loss : 0.598925 +[20:41:28.807] iteration 500 [254.67 sec]: learning rate : 0.000500 loss : 0.592256 +[20:42:06.849] Epoch 0 Evaluation: +[20:42:52.044] average MSE: 0.05577221530550083 average PSNR: 25.569599154449392 average SSIM: 0.6924716677203395 +[20:43:04.235] iteration 600 [12.17 sec]: learning rate : 0.000500 loss : 0.500626 +[20:43:54.253] iteration 700 [62.18 sec]: learning rate : 0.000500 loss : 0.402587 +[20:44:44.209] iteration 800 [112.14 sec]: learning rate : 0.000500 loss : 0.495099 +[20:45:34.221] iteration 900 [162.15 sec]: learning rate : 0.000500 loss : 0.385608 +[20:46:24.772] iteration 1000 [212.70 sec]: learning rate : 0.000500 loss : 0.367101 +[20:47:14.645] iteration 1100 [262.57 sec]: learning rate : 0.000500 loss : 0.546024 +[20:47:40.685] Epoch 1 Evaluation: +[20:48:26.497] average MSE: 0.04877883951754869 average PSNR: 26.156532421793855 average SSIM: 0.7067415972375548 +[20:48:50.666] iteration 1200 [24.14 sec]: learning rate : 0.000500 loss : 0.397528 +[20:49:40.741] iteration 1300 [74.22 sec]: learning rate : 0.000500 loss : 0.383576 +[20:50:31.240] iteration 1400 [124.72 sec]: learning rate : 0.000500 loss : 0.345545 +[20:51:21.290] iteration 1500 [174.77 sec]: 
learning rate : 0.000500 loss : 0.494478 +[20:52:11.293] iteration 1600 [224.77 sec]: learning rate : 0.000500 loss : 0.477217 +[20:53:01.331] iteration 1700 [274.81 sec]: learning rate : 0.000500 loss : 0.454170 +[20:53:15.304] Epoch 2 Evaluation: +[20:54:00.411] average MSE: 0.04962896356463407 average PSNR: 26.07867820268169 average SSIM: 0.7015850420542068 +[20:54:36.692] iteration 1800 [36.25 sec]: learning rate : 0.000500 loss : 0.332100 +[20:55:26.965] iteration 1900 [86.53 sec]: learning rate : 0.000500 loss : 0.471132 +[20:56:17.003] iteration 2000 [136.57 sec]: learning rate : 0.000500 loss : 0.537053 +[20:57:07.023] iteration 2100 [186.59 sec]: learning rate : 0.000500 loss : 0.450319 +[20:57:57.268] iteration 2200 [236.83 sec]: learning rate : 0.000500 loss : 1.634137 +[20:58:47.340] iteration 2300 [286.90 sec]: learning rate : 0.000500 loss : 2.225981 +[20:58:49.350] Epoch 3 Evaluation: +[20:59:36.313] average MSE: 0.047822480949510006 average PSNR: 26.24152494926041 average SSIM: 0.7087358032734624 +[21:00:24.429] iteration 2400 [48.09 sec]: learning rate : 0.000500 loss : 0.378685 +[21:01:14.468] iteration 2500 [98.13 sec]: learning rate : 0.000500 loss : 0.452092 +[21:02:04.435] iteration 2600 [148.10 sec]: learning rate : 0.000500 loss : 0.253784 +[21:02:54.638] iteration 2700 [198.30 sec]: learning rate : 0.000500 loss : 0.407612 +[21:03:44.768] iteration 2800 [248.43 sec]: learning rate : 0.000500 loss : 0.402118 +[21:04:25.297] Epoch 4 Evaluation: +[21:05:11.088] average MSE: 0.04704155072197187 average PSNR: 26.309475359932847 average SSIM: 0.7074169853454914 +[21:05:21.276] iteration 2900 [10.16 sec]: learning rate : 0.000500 loss : 0.406876 +[21:06:11.153] iteration 3000 [60.04 sec]: learning rate : 0.000500 loss : 0.455145 +[21:07:01.212] iteration 3100 [110.10 sec]: learning rate : 0.000500 loss : 0.333167 +[21:07:51.209] iteration 3200 [160.09 sec]: learning rate : 0.000500 loss : 0.353548 +[21:08:41.700] iteration 3300 [210.59 sec]: learning rate : 0.000500 loss : 0.388246 +[21:09:31.687] iteration 3400 [260.57 sec]: learning rate : 0.000500 loss : 0.417199 +[21:09:59.912] Epoch 5 Evaluation: +[21:10:46.236] average MSE: 0.04859357621725072 average PSNR: 26.159094914133902 average SSIM: 0.7131373115441086 +[21:11:08.550] iteration 3500 [22.29 sec]: learning rate : 0.000500 loss : 0.440048 +[21:11:58.446] iteration 3600 [72.18 sec]: learning rate : 0.000500 loss : 0.442913 +[21:12:48.514] iteration 3700 [122.25 sec]: learning rate : 0.000500 loss : 0.393228 +[21:13:38.437] iteration 3800 [172.18 sec]: learning rate : 0.000500 loss : 0.441884 +[21:14:28.445] iteration 3900 [222.18 sec]: learning rate : 0.000500 loss : 0.510846 +[21:15:18.453] iteration 4000 [272.19 sec]: learning rate : 0.000500 loss : 0.311368 +[21:15:34.414] Epoch 6 Evaluation: +[21:16:20.922] average MSE: 0.04179657683815543 average PSNR: 26.82775023418517 average SSIM: 0.7288363958934305 +[21:16:55.048] iteration 4100 [34.10 sec]: learning rate : 0.000500 loss : 0.421985 +[21:17:45.363] iteration 4200 [84.42 sec]: learning rate : 0.000500 loss : 0.466873 +[21:18:35.514] iteration 4300 [134.57 sec]: learning rate : 0.000500 loss : 0.343464 +[21:19:25.428] iteration 4400 [184.48 sec]: learning rate : 0.000500 loss : 0.580090 +[21:20:15.436] iteration 4500 [234.49 sec]: learning rate : 0.000500 loss : 0.458488 +[21:21:05.449] iteration 4600 [284.50 sec]: learning rate : 0.000500 loss : 0.414047 +[21:21:09.446] Epoch 7 Evaluation: +[21:21:55.693] average MSE: 0.046237489903743066 average PSNR: 
26.38406634045018 average SSIM: 0.7141888267557891 +[21:22:42.189] iteration 4700 [46.47 sec]: learning rate : 0.000500 loss : 0.347152 +[21:23:32.217] iteration 4800 [96.50 sec]: learning rate : 0.000500 loss : 0.487376 +[21:24:22.235] iteration 4900 [146.52 sec]: learning rate : 0.000500 loss : 0.461254 +[21:25:12.188] iteration 5000 [196.47 sec]: learning rate : 0.000500 loss : 0.330032 +[21:26:02.754] iteration 5100 [247.04 sec]: learning rate : 0.000500 loss : 0.478779 +[21:26:44.817] Epoch 8 Evaluation: +[21:27:31.941] average MSE: 0.0431449596036267 average PSNR: 26.680344771529064 average SSIM: 0.7180873765645992 +[21:27:40.143] iteration 5200 [8.18 sec]: learning rate : 0.000500 loss : 0.472951 +[21:28:30.837] iteration 5300 [58.87 sec]: learning rate : 0.000500 loss : 0.442390 +[21:29:20.885] iteration 5400 [108.92 sec]: learning rate : 0.000500 loss : 0.389290 +[21:30:10.867] iteration 5500 [158.90 sec]: learning rate : 0.000500 loss : 0.492017 +[21:31:00.954] iteration 5600 [208.99 sec]: learning rate : 0.000500 loss : 0.346600 +[21:31:51.072] iteration 5700 [259.11 sec]: learning rate : 0.000500 loss : 0.433869 +[21:32:21.150] Epoch 9 Evaluation: +[21:33:07.126] average MSE: 0.045428480781652375 average PSNR: 26.46242344691081 average SSIM: 0.7071758509397815 +[21:33:27.297] iteration 5800 [20.14 sec]: learning rate : 0.000500 loss : 0.400760 +[21:34:17.309] iteration 5900 [70.16 sec]: learning rate : 0.000500 loss : 0.489778 +[21:35:07.945] iteration 6000 [120.79 sec]: learning rate : 0.000500 loss : 0.285342 +[21:35:57.908] iteration 6100 [170.76 sec]: learning rate : 0.000500 loss : 0.504548 +[21:36:47.950] iteration 6200 [220.80 sec]: learning rate : 0.000500 loss : 1.419888 +[21:37:38.016] iteration 6300 [270.86 sec]: learning rate : 0.000500 loss : 0.406041 +[21:37:55.993] Epoch 10 Evaluation: +[21:38:41.768] average MSE: 0.05506314007012765 average PSNR: 25.62057934977427 average SSIM: 0.691538063125597 +[21:39:13.942] iteration 6400 [32.15 sec]: learning rate : 0.000500 loss : 0.372825 +[21:40:03.988] iteration 6500 [82.19 sec]: learning rate : 0.000500 loss : 0.392524 +[21:40:54.619] iteration 6600 [132.83 sec]: learning rate : 0.000500 loss : 0.414928 +[21:41:44.734] iteration 6700 [182.94 sec]: learning rate : 0.000500 loss : 0.394990 +[21:42:34.833] iteration 6800 [233.04 sec]: learning rate : 0.000500 loss : 0.449842 +[21:43:24.829] iteration 6900 [283.03 sec]: learning rate : 0.000500 loss : 0.444103 +[21:43:30.834] Epoch 11 Evaluation: +[21:44:19.068] average MSE: 0.045907566982947505 average PSNR: 26.40755975276817 average SSIM: 0.7033764358102633 +[21:45:03.377] iteration 7000 [44.28 sec]: learning rate : 0.000500 loss : 0.428942 +[21:45:53.382] iteration 7100 [94.29 sec]: learning rate : 0.000500 loss : 0.448294 +[21:46:43.855] iteration 7200 [144.76 sec]: learning rate : 0.000500 loss : 0.466432 +[21:47:33.969] iteration 7300 [194.88 sec]: learning rate : 0.000500 loss : 0.338263 +[21:48:24.477] iteration 7400 [245.38 sec]: learning rate : 0.000500 loss : 1.393043 +[21:49:08.471] Epoch 12 Evaluation: +[21:49:54.954] average MSE: 0.039302503523315074 average PSNR: 27.090284017281363 average SSIM: 0.7396235676154543 +[21:50:01.192] iteration 7500 [6.21 sec]: learning rate : 0.000500 loss : 0.471576 +[21:50:51.345] iteration 7600 [56.37 sec]: learning rate : 0.000500 loss : 0.350155 +[21:51:41.400] iteration 7700 [106.42 sec]: learning rate : 0.000500 loss : 0.407024 +[21:52:31.458] iteration 7800 [156.48 sec]: learning rate : 0.000500 loss : 0.396056 
+[21:53:21.803] iteration 7900 [206.82 sec]: learning rate : 0.000500 loss : 0.374480 +[21:54:11.721] iteration 8000 [256.74 sec]: learning rate : 0.000500 loss : 0.337560 +[21:54:43.780] Epoch 13 Evaluation: +[21:55:29.426] average MSE: 0.042422657438107636 average PSNR: 26.757599726535926 average SSIM: 0.7097826686424864 +[21:55:47.671] iteration 8100 [18.22 sec]: learning rate : 0.000500 loss : 0.424589 +[21:56:37.825] iteration 8200 [68.37 sec]: learning rate : 0.000500 loss : 0.351442 +[21:57:27.846] iteration 8300 [118.40 sec]: learning rate : 0.000500 loss : 0.336506 +[21:58:17.947] iteration 8400 [168.50 sec]: learning rate : 0.000500 loss : 0.284910 +[21:59:08.512] iteration 8500 [219.06 sec]: learning rate : 0.000500 loss : 0.431281 +[21:59:58.519] iteration 8600 [269.07 sec]: learning rate : 0.000500 loss : 0.372693 +[22:00:18.529] Epoch 14 Evaluation: +[22:01:05.765] average MSE: 0.04178735431051043 average PSNR: 26.822093417388224 average SSIM: 0.7220457592320081 +[22:01:35.946] iteration 8700 [30.15 sec]: learning rate : 0.000500 loss : 0.353082 +[22:02:26.514] iteration 8800 [80.72 sec]: learning rate : 0.000500 loss : 0.403710 +[22:03:16.864] iteration 8900 [131.07 sec]: learning rate : 0.000500 loss : 0.362573 +[22:04:06.876] iteration 9000 [181.09 sec]: learning rate : 0.000500 loss : 0.355737 +[22:04:56.954] iteration 9100 [231.16 sec]: learning rate : 0.000500 loss : 0.385694 +[22:05:47.065] iteration 9200 [281.27 sec]: learning rate : 0.000500 loss : 0.355071 +[22:05:55.073] Epoch 15 Evaluation: +[22:06:42.276] average MSE: 0.05370488261759135 average PSNR: 25.75570345486164 average SSIM: 0.7033467564434341 +[22:07:24.675] iteration 9300 [42.37 sec]: learning rate : 0.000500 loss : 0.419461 +[22:08:14.636] iteration 9400 [92.33 sec]: learning rate : 0.000500 loss : 0.369078 +[22:09:04.645] iteration 9500 [142.34 sec]: learning rate : 0.000500 loss : 0.491548 +[22:09:54.741] iteration 9600 [192.44 sec]: learning rate : 0.000500 loss : 0.466800 +[22:10:46.128] iteration 9700 [243.83 sec]: learning rate : 0.000500 loss : 0.345852 +[22:11:32.181] Epoch 16 Evaluation: +[22:12:16.859] average MSE: 0.04742912281541717 average PSNR: 26.28873713709677 average SSIM: 0.6914244649327967 +[22:12:21.069] iteration 9800 [4.18 sec]: learning rate : 0.000500 loss : 0.392733 +[22:13:11.148] iteration 9900 [54.26 sec]: learning rate : 0.000500 loss : 0.381699 +[22:14:01.116] iteration 10000 [104.23 sec]: learning rate : 0.000500 loss : 0.530212 +[22:14:51.233] iteration 10100 [154.35 sec]: learning rate : 0.000500 loss : 0.298771 +[22:15:41.204] iteration 10200 [204.32 sec]: learning rate : 0.000500 loss : 0.532272 +[22:16:31.216] iteration 10300 [254.33 sec]: learning rate : 0.000500 loss : 0.329047 +[22:17:05.602] Epoch 17 Evaluation: +[22:17:52.047] average MSE: 0.04389222931269823 average PSNR: 26.61817183808145 average SSIM: 0.7019826251574399 +[22:18:08.417] iteration 10400 [16.35 sec]: learning rate : 0.000500 loss : 0.380676 +[22:18:58.321] iteration 10500 [66.25 sec]: learning rate : 0.000500 loss : 0.330941 +[22:19:48.699] iteration 10600 [116.63 sec]: learning rate : 0.000500 loss : 0.364486 +[22:20:38.748] iteration 10700 [166.69 sec]: learning rate : 0.000500 loss : 0.411069 +[22:21:28.703] iteration 10800 [216.63 sec]: learning rate : 0.000500 loss : 0.313347 +[22:22:18.812] iteration 10900 [266.74 sec]: learning rate : 0.000500 loss : 0.363048 +[22:22:40.798] Epoch 18 Evaluation: +[22:23:26.505] average MSE: 0.049421589822476464 average PSNR: 26.102089778952433 average 
SSIM: 0.6924051571664505 +[22:23:54.838] iteration 11000 [28.30 sec]: learning rate : 0.000500 loss : 0.344809 +[22:24:44.805] iteration 11100 [78.27 sec]: learning rate : 0.000500 loss : 0.333594 +[22:25:35.151] iteration 11200 [128.62 sec]: learning rate : 0.000500 loss : 0.389535 +[22:26:25.072] iteration 11300 [178.54 sec]: learning rate : 0.000500 loss : 0.467124 +[22:27:15.082] iteration 11400 [228.55 sec]: learning rate : 0.000500 loss : 0.438954 +[22:28:05.108] iteration 11500 [278.57 sec]: learning rate : 0.000500 loss : 0.390876 +[22:28:15.095] Epoch 19 Evaluation: +[22:29:02.886] average MSE: 0.11912964841940817 average PSNR: 22.247746221089542 average SSIM: 0.617314373120688 +[22:29:43.092] iteration 11600 [40.18 sec]: learning rate : 0.000500 loss : 0.365978 +[22:30:33.163] iteration 11700 [90.25 sec]: learning rate : 0.000500 loss : 0.472499 +[22:31:23.153] iteration 11800 [140.24 sec]: learning rate : 0.000500 loss : 0.345234 +[22:32:13.148] iteration 11900 [190.24 sec]: learning rate : 0.000500 loss : 0.400064 +[22:33:03.264] iteration 12000 [240.35 sec]: learning rate : 0.000500 loss : 0.375674 +[22:33:51.207] Epoch 20 Evaluation: +[22:34:36.378] average MSE: 0.04348806252729702 average PSNR: 26.646225858868124 average SSIM: 0.7189695027676742 +[22:34:38.778] iteration 12100 [2.37 sec]: learning rate : 0.000500 loss : 0.403532 +[22:35:29.310] iteration 12200 [52.91 sec]: learning rate : 0.000500 loss : 0.348401 +[22:36:19.327] iteration 12300 [102.92 sec]: learning rate : 0.000500 loss : 0.566085 +[22:37:09.260] iteration 12400 [152.86 sec]: learning rate : 0.000500 loss : 0.310824 +[22:37:59.584] iteration 12500 [203.18 sec]: learning rate : 0.000500 loss : 0.376252 +[22:38:49.636] iteration 12600 [253.23 sec]: learning rate : 0.000500 loss : 0.938815 +[22:39:25.600] Epoch 21 Evaluation: +[22:40:13.410] average MSE: 1100.541288451885 average PSNR: -17.34640705318093 average SSIM: 6.168837621167061e-06 +[22:40:28.080] iteration 12700 [14.64 sec]: learning rate : 0.000500 loss : 0.676369 +[22:41:18.875] iteration 12800 [65.43 sec]: learning rate : 0.000500 loss : 0.544157 +[22:42:08.958] iteration 12900 [115.52 sec]: learning rate : 0.000500 loss : 0.625339 +[22:42:58.949] iteration 13000 [165.51 sec]: learning rate : 0.000500 loss : 0.532351 +[22:43:49.039] iteration 13100 [215.60 sec]: learning rate : 0.000500 loss : 0.427755 +[22:44:39.118] iteration 13200 [265.68 sec]: learning rate : 0.000500 loss : 0.364084 +[22:45:03.118] Epoch 22 Evaluation: +[22:45:50.263] average MSE: 0.043134186694685994 average PSNR: 26.71468836343044 average SSIM: 0.7289398639770789 +[22:46:16.463] iteration 13300 [26.17 sec]: learning rate : 0.000500 loss : 0.456377 +[22:47:07.070] iteration 13400 [76.78 sec]: learning rate : 0.000500 loss : 0.456865 +[22:47:58.053] iteration 13500 [127.76 sec]: learning rate : 0.000500 loss : 0.372322 +[22:48:48.177] iteration 13600 [177.89 sec]: learning rate : 0.000500 loss : 0.546932 +[22:49:38.279] iteration 13700 [227.99 sec]: learning rate : 0.000500 loss : 0.355562 +[22:50:28.272] iteration 13800 [277.98 sec]: learning rate : 0.000500 loss : 0.407863 +[22:50:40.367] Epoch 23 Evaluation: +[22:51:25.306] average MSE: 0.03603720088573622 average PSNR: 27.469498462115602 average SSIM: 0.7674314175144143 +[22:52:03.553] iteration 13900 [38.22 sec]: learning rate : 0.000500 loss : 0.424905 +[22:52:53.637] iteration 14000 [88.30 sec]: learning rate : 0.000500 loss : 0.583009 +[22:53:43.950] iteration 14100 [138.62 sec]: learning rate : 0.000500 loss : 
0.414487 +[22:54:33.989] iteration 14200 [188.66 sec]: learning rate : 0.000500 loss : 0.328933 +[22:55:23.981] iteration 14300 [238.65 sec]: learning rate : 0.000500 loss : 0.373448 +[22:56:14.820] iteration 14400 [289.49 sec]: learning rate : 0.000500 loss : 0.401837 +[22:56:14.861] Epoch 24 Evaluation: +[22:57:00.407] average MSE: 0.04931789580219231 average PSNR: 26.102908110570766 average SSIM: 0.6823227037413261 +[22:57:50.721] iteration 14500 [50.29 sec]: learning rate : 0.000500 loss : 0.362840 +[22:58:40.791] iteration 14600 [100.36 sec]: learning rate : 0.000500 loss : 0.393686 +[22:59:31.160] iteration 14700 [150.73 sec]: learning rate : 0.000500 loss : 0.306047 +[23:00:21.256] iteration 14800 [200.82 sec]: learning rate : 0.000500 loss : 0.395021 +[23:01:11.237] iteration 14900 [250.80 sec]: learning rate : 0.000500 loss : 0.423229 +[23:01:49.340] Epoch 25 Evaluation: +[23:02:35.777] average MSE: 0.06428248805285032 average PSNR: 24.944096152239734 average SSIM: 0.6526348983572788 +[23:02:48.012] iteration 15000 [12.21 sec]: learning rate : 0.000500 loss : 0.397306 +[23:03:38.779] iteration 15100 [62.98 sec]: learning rate : 0.000500 loss : 0.365987 +[23:04:28.698] iteration 15200 [112.90 sec]: learning rate : 0.000500 loss : 0.383691 +[23:05:19.119] iteration 15300 [163.32 sec]: learning rate : 0.000500 loss : 0.347539 +[23:06:09.605] iteration 15400 [213.80 sec]: learning rate : 0.000500 loss : 0.329667 +[23:06:59.582] iteration 15500 [263.78 sec]: learning rate : 0.000500 loss : 0.468225 +[23:07:25.672] Epoch 26 Evaluation: +[23:08:10.327] average MSE: 0.07063886461884027 average PSNR: 24.52746785752841 average SSIM: 0.6611669083458785 +[23:08:34.549] iteration 15600 [24.20 sec]: learning rate : 0.000500 loss : 0.296202 +[23:09:24.677] iteration 15700 [74.32 sec]: learning rate : 0.000500 loss : 0.395204 +[23:10:14.688] iteration 15800 [124.33 sec]: learning rate : 0.000500 loss : 0.319227 +[23:11:05.318] iteration 15900 [174.96 sec]: learning rate : 0.000500 loss : 0.432971 +[23:11:55.704] iteration 16000 [225.35 sec]: learning rate : 0.000500 loss : 0.481210 +[23:12:45.628] iteration 16100 [275.27 sec]: learning rate : 0.000500 loss : 0.408516 +[23:12:59.619] Epoch 27 Evaluation: +[23:13:44.469] average MSE: 0.03861670477521635 average PSNR: 27.174667627046116 average SSIM: 0.7371818001686912 +[23:14:20.895] iteration 16200 [36.40 sec]: learning rate : 0.000500 loss : 0.274836 +[23:15:10.870] iteration 16300 [86.37 sec]: learning rate : 0.000500 loss : 0.359590 +[23:16:00.962] iteration 16400 [136.46 sec]: learning rate : 0.000500 loss : 0.401920 +[23:16:51.047] iteration 16500 [186.55 sec]: learning rate : 0.000500 loss : 0.310409 +[23:17:41.319] iteration 16600 [236.82 sec]: learning rate : 0.000500 loss : 0.400930 +[23:18:31.697] iteration 16700 [287.20 sec]: learning rate : 0.000500 loss : 0.389508 +[23:18:33.706] Epoch 28 Evaluation: +[23:19:18.891] average MSE: 0.054438908295018865 average PSNR: 25.69907643226203 average SSIM: 0.6765256905886282 +[23:20:07.185] iteration 16800 [48.27 sec]: learning rate : 0.000500 loss : 0.331532 +[23:20:57.124] iteration 16900 [98.21 sec]: learning rate : 0.000500 loss : 0.404239 +[23:21:47.186] iteration 17000 [148.27 sec]: learning rate : 0.000500 loss : 0.262291 +[23:22:37.695] iteration 17100 [198.78 sec]: learning rate : 0.000500 loss : 0.399216 +[23:23:27.630] iteration 17200 [248.71 sec]: learning rate : 0.000500 loss : 0.289830 +[23:24:08.247] Epoch 29 Evaluation: +[23:24:53.271] average MSE: 0.06265621952616524 average 
PSNR: 25.09563152570973 average SSIM: 0.6664277890781836 +[23:25:03.481] iteration 17300 [10.18 sec]: learning rate : 0.000500 loss : 0.308198 +[23:25:54.077] iteration 17400 [60.78 sec]: learning rate : 0.000500 loss : 0.376999 +[23:26:44.062] iteration 17500 [110.77 sec]: learning rate : 0.000500 loss : 0.316157 +[23:27:34.090] iteration 17600 [160.80 sec]: learning rate : 0.000500 loss : 0.256134 +[23:28:24.015] iteration 17700 [210.72 sec]: learning rate : 0.000500 loss : 0.412708 +[23:29:14.011] iteration 17800 [260.72 sec]: learning rate : 0.000500 loss : 0.652511 +[23:29:42.282] Epoch 30 Evaluation: +[23:30:28.488] average MSE: 0.0456172093641807 average PSNR: 26.46037542389843 average SSIM: 0.7031121843638907 +[23:30:50.886] iteration 17900 [22.37 sec]: learning rate : 0.000500 loss : 0.359682 +[23:31:41.463] iteration 18000 [72.95 sec]: learning rate : 0.000500 loss : 0.306026 +[23:32:31.562] iteration 18100 [123.05 sec]: learning rate : 0.000500 loss : 0.328921 +[23:33:22.148] iteration 18200 [173.63 sec]: learning rate : 0.000500 loss : 0.318704 +[23:34:12.190] iteration 18300 [223.68 sec]: learning rate : 0.000500 loss : 0.349187 +[23:35:02.309] iteration 18400 [273.79 sec]: learning rate : 0.000500 loss : 0.409681 +[23:35:18.323] Epoch 31 Evaluation: +[23:36:06.539] average MSE: 0.057776282373748954 average PSNR: 25.411040432250086 average SSIM: 0.6635069178359316 +[23:36:40.903] iteration 18500 [34.34 sec]: learning rate : 0.000500 loss : 0.410762 +[23:37:30.850] iteration 18600 [84.28 sec]: learning rate : 0.000500 loss : 0.394779 +[23:38:20.911] iteration 18700 [134.35 sec]: learning rate : 0.000500 loss : 0.277452 +[23:39:10.927] iteration 18800 [184.36 sec]: learning rate : 0.000500 loss : 0.737081 +[23:40:01.003] iteration 18900 [234.44 sec]: learning rate : 0.000500 loss : 0.423470 +[23:40:52.528] iteration 19000 [285.96 sec]: learning rate : 0.000500 loss : 0.387297 +[23:40:56.538] Epoch 32 Evaluation: +[23:41:41.917] average MSE: 0.048108676218362734 average PSNR: 26.247899440601326 average SSIM: 0.7034086765760703 +[23:42:28.540] iteration 19100 [46.60 sec]: learning rate : 0.000500 loss : 0.321678 +[23:43:18.650] iteration 19200 [96.71 sec]: learning rate : 0.000500 loss : 0.555446 +[23:44:08.671] iteration 19300 [146.73 sec]: learning rate : 0.000500 loss : 0.395450 +[23:44:58.610] iteration 19400 [196.67 sec]: learning rate : 0.000500 loss : 0.348239 +[23:45:48.624] iteration 19500 [246.68 sec]: learning rate : 0.000500 loss : 0.435682 +[23:46:30.557] Epoch 33 Evaluation: +[23:47:16.156] average MSE: 0.046995708941803986 average PSNR: 26.322308726158816 average SSIM: 0.7059541763957788 +[23:47:24.539] iteration 19600 [8.36 sec]: learning rate : 0.000500 loss : 0.421953 +[23:48:15.482] iteration 19700 [59.30 sec]: learning rate : 0.000500 loss : 0.442576 +[23:49:05.585] iteration 19800 [109.40 sec]: learning rate : 0.000500 loss : 0.462102 +[23:49:55.828] iteration 19900 [159.65 sec]: learning rate : 0.000500 loss : 0.317593 +[23:50:45.930] iteration 20000 [209.75 sec]: learning rate : 0.000125 loss : 0.476366 +[23:50:46.088] save model to model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/iter_20000.pth +[23:51:36.163] iteration 20100 [259.98 sec]: learning rate : 0.000250 loss : 0.340755 +[23:52:06.133] Epoch 34 Evaluation: +[23:52:50.657] average MSE: 0.05363643028662837 average PSNR: 25.773458889039937 average SSIM: 0.701108139033757 +[23:53:10.846] iteration 20200 [20.16 sec]: learning rate : 0.000250 loss : 0.393342 +[23:54:01.431] iteration 20300 [70.75 sec]: 
learning rate : 0.000250 loss : 0.393297 +[23:54:51.526] iteration 20400 [120.84 sec]: learning rate : 0.000250 loss : 0.291281 +[23:55:41.947] iteration 20500 [171.27 sec]: learning rate : 0.000250 loss : 0.400188 +[23:56:32.071] iteration 20600 [221.39 sec]: learning rate : 0.000250 loss : 0.339016 +[23:57:22.151] iteration 20700 [271.47 sec]: learning rate : 0.000250 loss : 0.420579 +[23:57:40.144] Epoch 35 Evaluation: +[23:58:25.792] average MSE: 0.08407550763384129 average PSNR: 23.780197277159107 average SSIM: 0.6474304383906048 +[23:58:57.994] iteration 20800 [32.17 sec]: learning rate : 0.000250 loss : 0.377161 +[23:59:48.055] iteration 20900 [82.23 sec]: learning rate : 0.000250 loss : 0.285801 +[00:00:38.227] iteration 21000 [132.40 sec]: learning rate : 0.000250 loss : 0.364529 +[00:01:28.286] iteration 21100 [182.46 sec]: learning rate : 0.000250 loss : 1.522083 +[00:02:18.368] iteration 21200 [232.54 sec]: learning rate : 0.000250 loss : 0.358630 +[00:03:08.826] iteration 21300 [283.00 sec]: learning rate : 0.000250 loss : 0.310108 +[00:03:14.830] Epoch 36 Evaluation: +[00:04:01.380] average MSE: 0.06098810200028907 average PSNR: 25.187034086932663 average SSIM: 0.659008590549104 +[00:04:45.779] iteration 21400 [44.37 sec]: learning rate : 0.000250 loss : 0.381113 +[00:05:35.868] iteration 21500 [94.46 sec]: learning rate : 0.000250 loss : 0.407990 +[00:06:26.298] iteration 21600 [144.89 sec]: learning rate : 0.000250 loss : 0.424683 +[00:07:16.439] iteration 21700 [195.03 sec]: learning rate : 0.000250 loss : 0.326000 +[00:08:06.805] iteration 21800 [245.40 sec]: learning rate : 0.000250 loss : 0.325453 +[00:08:50.749] Epoch 37 Evaluation: +[00:09:35.659] average MSE: 0.05042717077814512 average PSNR: 26.005237600955702 average SSIM: 0.6894919133909213 +[00:09:41.915] iteration 21900 [6.23 sec]: learning rate : 0.000250 loss : 0.298694 +[00:10:32.602] iteration 22000 [56.92 sec]: learning rate : 0.000250 loss : 0.284620 +[00:11:22.699] iteration 22100 [107.01 sec]: learning rate : 0.000250 loss : 0.369264 +[00:12:13.080] iteration 22200 [157.39 sec]: learning rate : 0.000250 loss : 0.365468 +[00:13:03.192] iteration 22300 [207.51 sec]: learning rate : 0.000250 loss : 0.384624 +[00:13:53.160] iteration 22400 [257.47 sec]: learning rate : 0.000250 loss : 0.309964 +[00:14:25.266] Epoch 38 Evaluation: +[00:15:11.035] average MSE: 0.06878425035593125 average PSNR: 24.650355974954863 average SSIM: 0.6493306548390326 +[00:15:29.225] iteration 22500 [18.16 sec]: learning rate : 0.000250 loss : 0.368032 +[00:16:19.643] iteration 22600 [68.58 sec]: learning rate : 0.000250 loss : 0.332439 +[00:17:09.591] iteration 22700 [118.53 sec]: learning rate : 0.000250 loss : 0.329039 +[00:18:00.057] iteration 22800 [169.00 sec]: learning rate : 0.000250 loss : 0.238384 +[00:18:50.399] iteration 22900 [219.34 sec]: learning rate : 0.000250 loss : 0.385706 +[00:19:40.329] iteration 23000 [269.27 sec]: learning rate : 0.000250 loss : 0.371094 +[00:20:00.443] Epoch 39 Evaluation: +[00:20:45.698] average MSE: 0.049453198454166354 average PSNR: 26.089303056927893 average SSIM: 0.690659533682256 +[00:21:15.869] iteration 23100 [30.14 sec]: learning rate : 0.000250 loss : 0.364016 +[00:22:05.942] iteration 23200 [80.22 sec]: learning rate : 0.000250 loss : 0.307588 +[00:22:55.904] iteration 23300 [130.18 sec]: learning rate : 0.000250 loss : 0.368185 +[00:23:45.906] iteration 23400 [180.18 sec]: learning rate : 0.000250 loss : 0.300967 +[00:24:36.291] iteration 23500 [230.57 sec]: learning rate : 
0.000250 loss : 0.330821 +[00:25:26.918] iteration 23600 [281.19 sec]: learning rate : 0.000250 loss : 0.378572 +[00:25:34.924] Epoch 40 Evaluation: +[00:26:21.910] average MSE: 0.07689654720621066 average PSNR: 24.177919621166314 average SSIM: 0.6441668308017254 +[00:27:04.214] iteration 23700 [42.28 sec]: learning rate : 0.000250 loss : 0.386985 +[00:27:54.119] iteration 23800 [92.18 sec]: learning rate : 0.000250 loss : 0.353327 +[00:28:44.141] iteration 23900 [142.20 sec]: learning rate : 0.000250 loss : 0.380342 +[00:29:34.230] iteration 24000 [192.29 sec]: learning rate : 0.000250 loss : 0.326385 +[00:30:24.851] iteration 24100 [242.92 sec]: learning rate : 0.000250 loss : 0.290424 +[00:31:10.995] Epoch 41 Evaluation: +[00:31:58.083] average MSE: 0.05339544528114652 average PSNR: 25.759915134805148 average SSIM: 0.6788931071194186 +[00:32:02.331] iteration 24200 [4.22 sec]: learning rate : 0.000250 loss : 0.410190 +[00:32:52.849] iteration 24300 [54.74 sec]: learning rate : 0.000250 loss : 0.337755 +[00:33:42.771] iteration 24400 [104.66 sec]: learning rate : 0.000250 loss : 0.360037 +[00:34:33.172] iteration 24500 [155.06 sec]: learning rate : 0.000250 loss : 0.343235 +[00:35:23.254] iteration 24600 [205.15 sec]: learning rate : 0.000250 loss : 0.437842 +[00:36:13.364] iteration 24700 [255.26 sec]: learning rate : 0.000250 loss : 0.352733 +[00:36:47.707] Epoch 42 Evaluation: +[00:37:34.816] average MSE: 0.05065560542381654 average PSNR: 25.98720540145151 average SSIM: 0.6868219784164156 +[00:37:51.191] iteration 24800 [16.35 sec]: learning rate : 0.000250 loss : 0.297047 +[00:38:41.135] iteration 24900 [66.29 sec]: learning rate : 0.000250 loss : 0.315245 +[00:39:31.162] iteration 25000 [116.32 sec]: learning rate : 0.000250 loss : 0.349308 +[00:40:21.744] iteration 25100 [166.90 sec]: learning rate : 0.000250 loss : 0.306565 +[00:41:11.775] iteration 25200 [216.93 sec]: learning rate : 0.000250 loss : 0.350597 +[00:42:01.842] iteration 25300 [267.00 sec]: learning rate : 0.000250 loss : 0.271898 +[00:42:24.250] Epoch 43 Evaluation: +[00:43:09.843] average MSE: 0.054769976184719675 average PSNR: 25.660525505424037 average SSIM: 0.6810190302100579 +[00:43:38.239] iteration 25400 [28.39 sec]: learning rate : 0.000250 loss : 0.759866 +[00:44:28.220] iteration 25500 [78.35 sec]: learning rate : 0.000250 loss : 0.334894 +[00:45:18.289] iteration 25600 [128.42 sec]: learning rate : 0.000250 loss : 0.424232 +[00:46:08.260] iteration 25700 [178.39 sec]: learning rate : 0.000250 loss : 0.339157 +[00:46:58.290] iteration 25800 [228.42 sec]: learning rate : 0.000250 loss : 0.321946 +[00:47:49.034] iteration 25900 [279.17 sec]: learning rate : 0.000250 loss : 0.385644 +[00:47:59.039] Epoch 44 Evaluation: +[00:48:46.844] average MSE: 0.13412511033236482 average PSNR: 21.76371659668972 average SSIM: 0.5957266436072108 +[00:49:27.139] iteration 26000 [40.27 sec]: learning rate : 0.000250 loss : 0.473562 +[00:50:17.245] iteration 26100 [90.37 sec]: learning rate : 0.000250 loss : 0.370466 +[00:51:07.410] iteration 26200 [140.54 sec]: learning rate : 0.000250 loss : 0.376936 +[00:51:57.737] iteration 26300 [190.87 sec]: learning rate : 0.000250 loss : 0.429509 +[00:52:47.857] iteration 26400 [240.99 sec]: learning rate : 0.000250 loss : 0.293807 +[00:53:35.858] Epoch 45 Evaluation: +[00:54:22.649] average MSE: 0.10203195075194557 average PSNR: 22.925179375200276 average SSIM: 0.5840993357960284 +[00:54:25.018] iteration 26500 [2.34 sec]: learning rate : 0.000250 loss : 0.298019 +[00:55:15.798] 
iteration 26600 [53.12 sec]: learning rate : 0.000250 loss : 0.301423 +[00:56:05.834] iteration 26700 [103.16 sec]: learning rate : 0.000250 loss : 0.472936 +[00:56:55.794] iteration 26800 [153.12 sec]: learning rate : 0.000250 loss : 0.303263 +[00:57:45.879] iteration 26900 [203.20 sec]: learning rate : 0.000250 loss : 0.330673 +[00:58:35.959] iteration 27000 [253.28 sec]: learning rate : 0.000250 loss : 0.314002 +[00:59:11.959] Epoch 46 Evaluation: +[00:59:56.800] average MSE: 0.07481623550349967 average PSNR: 24.285313184801257 average SSIM: 0.6452512847492221 +[01:00:11.001] iteration 27100 [14.17 sec]: learning rate : 0.000250 loss : 0.286461 +[01:01:01.898] iteration 27200 [65.07 sec]: learning rate : 0.000250 loss : 0.422048 +[01:01:52.014] iteration 27300 [115.19 sec]: learning rate : 0.000250 loss : 0.376854 +[01:02:42.313] iteration 27400 [165.51 sec]: learning rate : 0.000250 loss : 0.357947 +[01:03:32.464] iteration 27500 [215.64 sec]: learning rate : 0.000250 loss : 0.327937 +[01:04:22.572] iteration 27600 [265.75 sec]: learning rate : 0.000250 loss : 0.333094 +[01:04:46.581] Epoch 47 Evaluation: +[01:05:33.523] average MSE: 0.18206277948194416 average PSNR: 20.460350143792464 average SSIM: 0.5656348339532097 +[01:05:59.712] iteration 27700 [26.16 sec]: learning rate : 0.000250 loss : 0.318374 +[01:06:50.258] iteration 27800 [76.71 sec]: learning rate : 0.000250 loss : 0.347261 +[01:07:40.246] iteration 27900 [126.70 sec]: learning rate : 0.000250 loss : 0.288608 +[01:08:30.322] iteration 28000 [176.77 sec]: learning rate : 0.000250 loss : 0.309979 +[01:09:20.383] iteration 28100 [226.84 sec]: learning rate : 0.000250 loss : 0.391212 +[01:10:10.923] iteration 28200 [277.37 sec]: learning rate : 0.000250 loss : 0.263100 +[01:10:23.237] Epoch 48 Evaluation: +[01:11:09.939] average MSE: 0.05083533822666779 average PSNR: 25.9914266633985 average SSIM: 0.680810708317126 +[01:11:48.385] iteration 28300 [38.42 sec]: learning rate : 0.000250 loss : 0.289414 +[01:12:38.479] iteration 28400 [88.51 sec]: learning rate : 0.000250 loss : 0.395418 +[01:13:28.863] iteration 28500 [138.90 sec]: learning rate : 0.000250 loss : 0.369520 +[01:14:18.996] iteration 28600 [189.03 sec]: learning rate : 0.000250 loss : 0.317730 +[01:15:09.086] iteration 28700 [239.12 sec]: learning rate : 0.000250 loss : 0.275531 +[01:15:59.052] iteration 28800 [289.09 sec]: learning rate : 0.000250 loss : 0.275450 +[01:15:59.092] Epoch 49 Evaluation: +[01:16:47.047] average MSE: 0.05672174558472981 average PSNR: 25.495351173476273 average SSIM: 0.6752805787539454 +[01:17:37.402] iteration 28900 [50.33 sec]: learning rate : 0.000250 loss : 0.276530 +[01:18:27.902] iteration 29000 [100.83 sec]: learning rate : 0.000250 loss : 0.337796 +[01:19:18.646] iteration 29100 [151.57 sec]: learning rate : 0.000250 loss : 0.302838 +[01:20:08.680] iteration 29200 [201.61 sec]: learning rate : 0.000250 loss : 0.369344 +[01:20:58.622] iteration 29300 [251.55 sec]: learning rate : 0.000250 loss : 0.316791 +[01:21:36.674] Epoch 50 Evaluation: +[01:22:21.596] average MSE: 0.14313290185140726 average PSNR: 21.49850293686803 average SSIM: 0.5758881258571981 +[01:22:33.847] iteration 29400 [12.22 sec]: learning rate : 0.000250 loss : 0.383213 +[01:23:23.916] iteration 29500 [62.29 sec]: learning rate : 0.000250 loss : 0.365002 +[01:24:13.831] iteration 29600 [112.21 sec]: learning rate : 0.000250 loss : 0.356436 +[01:25:04.192] iteration 29700 [162.57 sec]: learning rate : 0.000250 loss : 0.378173 +[01:25:54.748] iteration 29800 [213.13 
sec]: learning rate : 0.000250 loss : 0.306540 +[01:26:44.683] iteration 29900 [263.06 sec]: learning rate : 0.000250 loss : 0.377902 +[01:27:10.666] Epoch 51 Evaluation: +[01:27:57.021] average MSE: 0.09995349529404621 average PSNR: 23.029031498732376 average SSIM: 0.6043475436453128 +[01:28:21.377] iteration 30000 [24.33 sec]: learning rate : 0.000250 loss : 0.313928 +[01:29:11.298] iteration 30100 [74.25 sec]: learning rate : 0.000250 loss : 0.301947 +[01:30:01.332] iteration 30200 [124.29 sec]: learning rate : 0.000250 loss : 0.268314 +[01:30:51.834] iteration 30300 [174.79 sec]: learning rate : 0.000250 loss : 0.308663 +[01:31:41.784] iteration 30400 [224.74 sec]: learning rate : 0.000250 loss : 0.320603 +[01:32:31.851] iteration 30500 [274.80 sec]: learning rate : 0.000250 loss : 0.342151 +[01:32:45.860] Epoch 52 Evaluation: +[01:33:31.010] average MSE: 0.12366107730088738 average PSNR: 22.119408157643015 average SSIM: 0.6267877279954654 +[01:34:07.338] iteration 30600 [36.30 sec]: learning rate : 0.000250 loss : 0.254359 +[01:34:57.254] iteration 30700 [86.22 sec]: learning rate : 0.000250 loss : 0.362260 +[01:35:47.350] iteration 30800 [136.32 sec]: learning rate : 0.000250 loss : 0.361982 +[01:36:37.867] iteration 30900 [186.83 sec]: learning rate : 0.000250 loss : 0.356826 +[01:37:28.399] iteration 31000 [237.36 sec]: learning rate : 0.000250 loss : 0.332070 +[01:38:18.429] iteration 31100 [287.39 sec]: learning rate : 0.000250 loss : 0.397453 +[01:38:20.436] Epoch 53 Evaluation: +[01:39:05.394] average MSE: 0.11323660378349282 average PSNR: 22.486206790844058 average SSIM: 0.6171740198907668 +[01:39:53.614] iteration 31200 [48.19 sec]: learning rate : 0.000250 loss : 0.314397 +[01:40:44.044] iteration 31300 [98.62 sec]: learning rate : 0.000250 loss : 0.380595 +[01:41:34.147] iteration 31400 [148.73 sec]: learning rate : 0.000250 loss : 0.242409 +[01:42:24.135] iteration 31500 [198.71 sec]: learning rate : 0.000250 loss : 10.301101 +[01:43:14.390] iteration 31600 [248.97 sec]: learning rate : 0.000250 loss : 0.373495 +[01:43:54.433] Epoch 54 Evaluation: +[01:44:39.296] average MSE: 0.07850261207241792 average PSNR: 24.085153515116744 average SSIM: 0.6348277858389375 +[01:44:49.505] iteration 31700 [10.18 sec]: learning rate : 0.000250 loss : 0.345260 +[01:45:39.878] iteration 31800 [60.56 sec]: learning rate : 0.000250 loss : 0.360263 +[01:46:29.985] iteration 31900 [110.66 sec]: learning rate : 0.000250 loss : 0.332957 +[01:47:20.016] iteration 32000 [160.69 sec]: learning rate : 0.000250 loss : 0.303018 +[01:48:10.463] iteration 32100 [211.14 sec]: learning rate : 0.000250 loss : 0.346975 +[01:49:00.520] iteration 32200 [261.20 sec]: learning rate : 0.000250 loss : 0.308715 +[01:49:29.092] Epoch 55 Evaluation: +[01:50:15.775] average MSE: 0.07328892790847401 average PSNR: 24.377040784115074 average SSIM: 0.6320083979156734 +[01:50:38.184] iteration 32300 [22.38 sec]: learning rate : 0.000250 loss : 0.340689 +[01:51:28.176] iteration 32400 [72.38 sec]: learning rate : 0.000250 loss : 0.328887 +[01:52:18.294] iteration 32500 [122.49 sec]: learning rate : 0.000250 loss : 0.303582 +[01:53:08.323] iteration 32600 [172.52 sec]: learning rate : 0.000250 loss : 0.294413 +[01:53:58.423] iteration 32700 [222.62 sec]: learning rate : 0.000250 loss : 0.391229 +[01:54:48.957] iteration 32800 [273.16 sec]: learning rate : 0.000250 loss : 0.335249 +[01:55:04.972] Epoch 56 Evaluation: +[01:55:52.808] average MSE: 0.3003683880590477 average PSNR: 18.326930711240085 average SSIM: 
0.4640629879490061 +[01:56:27.003] iteration 32900 [34.17 sec]: learning rate : 0.000250 loss : 0.375772 +[01:57:17.096] iteration 33000 [84.26 sec]: learning rate : 0.000250 loss : 0.365938 +[01:58:07.128] iteration 33100 [134.29 sec]: learning rate : 0.000250 loss : 0.276245 +[01:58:57.097] iteration 33200 [184.26 sec]: learning rate : 0.000250 loss : 0.333378 +[01:59:47.165] iteration 33300 [234.33 sec]: learning rate : 0.000250 loss : 0.345506 +[02:00:37.103] iteration 33400 [284.26 sec]: learning rate : 0.000250 loss : 0.378781 +[02:00:41.228] Epoch 57 Evaluation: +[02:01:27.226] average MSE: 0.1136560446354602 average PSNR: 22.469549923562646 average SSIM: 0.6280099459282612 +[02:02:13.427] iteration 33500 [46.18 sec]: learning rate : 0.000250 loss : 0.327855 +[02:03:04.102] iteration 33600 [96.85 sec]: learning rate : 0.000250 loss : 0.435017 +[02:03:54.532] iteration 33700 [147.28 sec]: learning rate : 0.000250 loss : 0.418611 +[02:04:44.664] iteration 33800 [197.41 sec]: learning rate : 0.000250 loss : 0.289424 +[02:05:34.815] iteration 33900 [247.56 sec]: learning rate : 0.000250 loss : 0.407817 +[02:06:16.836] Epoch 58 Evaluation: +[02:07:04.384] average MSE: 0.058726862957999434 average PSNR: 25.358960739461914 average SSIM: 0.6626238633246703 +[02:07:12.832] iteration 34000 [8.42 sec]: learning rate : 0.000250 loss : 0.370972 +[02:08:03.024] iteration 34100 [58.61 sec]: learning rate : 0.000250 loss : 0.414609 +[02:08:53.071] iteration 34200 [108.66 sec]: learning rate : 0.000250 loss : 0.310613 +[02:09:43.063] iteration 34300 [158.65 sec]: learning rate : 0.000250 loss : 0.318708 +[02:10:33.587] iteration 34400 [209.18 sec]: learning rate : 0.000250 loss : 0.300961 +[02:11:23.518] iteration 34500 [259.11 sec]: learning rate : 0.000250 loss : 0.290298 +[02:11:53.628] Epoch 59 Evaluation: +[02:12:40.637] average MSE: 0.07930383586172213 average PSNR: 24.03995953525203 average SSIM: 0.6326001213654942 +[02:13:00.841] iteration 34600 [20.18 sec]: learning rate : 0.000250 loss : 0.285044 +[02:13:51.303] iteration 34700 [70.64 sec]: learning rate : 0.000250 loss : 0.359949 +[02:14:41.326] iteration 34800 [120.66 sec]: learning rate : 0.000250 loss : 0.229714 +[02:15:31.442] iteration 34900 [170.78 sec]: learning rate : 0.000250 loss : 0.387281 +[02:16:21.535] iteration 35000 [220.87 sec]: learning rate : 0.000250 loss : 0.289397 +[02:17:11.521] iteration 35100 [270.86 sec]: learning rate : 0.000250 loss : 0.363480 +[02:17:29.502] Epoch 60 Evaluation: +[02:18:16.350] average MSE: 0.06169364347836912 average PSNR: 25.120798556706188 average SSIM: 0.6537973263382731 +[02:18:48.676] iteration 35200 [32.30 sec]: learning rate : 0.000250 loss : 0.285336 +[02:19:39.272] iteration 35300 [82.90 sec]: learning rate : 0.000250 loss : 0.281581 +[02:20:29.221] iteration 35400 [132.85 sec]: learning rate : 0.000250 loss : 0.345084 +[02:21:19.763] iteration 35500 [183.39 sec]: learning rate : 0.000250 loss : 0.352206 +[02:22:09.791] iteration 35600 [233.41 sec]: learning rate : 0.000250 loss : 0.395330 +[02:22:59.827] iteration 35700 [283.45 sec]: learning rate : 0.000250 loss : 0.313528 +[02:23:05.836] Epoch 61 Evaluation: +[02:23:51.235] average MSE: 0.06427767957506049 average PSNR: 24.94844971344901 average SSIM: 0.6620950274675174 +[02:24:35.606] iteration 35800 [44.35 sec]: learning rate : 0.000250 loss : 0.363590 +[02:25:26.045] iteration 35900 [94.78 sec]: learning rate : 0.000250 loss : 0.355536 +[02:26:16.666] iteration 36000 [145.41 sec]: learning rate : 0.000250 loss : 0.451601 
+[02:27:06.763] iteration 36100 [195.50 sec]: learning rate : 0.000250 loss : 0.300391 +[02:27:56.740] iteration 36200 [245.48 sec]: learning rate : 0.000250 loss : 0.300415 +[02:28:40.827] Epoch 62 Evaluation: +[02:29:27.577] average MSE: 0.09108114133537634 average PSNR: 23.420711158144794 average SSIM: 0.6313519508104227 +[02:29:33.791] iteration 36300 [6.19 sec]: learning rate : 0.000250 loss : 0.332797 +[02:30:24.018] iteration 36400 [56.42 sec]: learning rate : 0.000250 loss : 0.263331 +[02:31:13.995] iteration 36500 [106.39 sec]: learning rate : 0.000250 loss : 0.267029 +[02:32:04.528] iteration 36600 [156.93 sec]: learning rate : 0.000250 loss : 0.397284 +[02:32:54.571] iteration 36700 [206.97 sec]: learning rate : 0.000250 loss : 0.300613 +[02:33:44.938] iteration 36800 [257.34 sec]: learning rate : 0.000250 loss : 0.298903 +[02:34:17.160] Epoch 63 Evaluation: +[02:35:01.898] average MSE: 0.04781957548801275 average PSNR: 26.23806978432122 average SSIM: 0.6922315281272949 +[02:35:20.090] iteration 36900 [18.17 sec]: learning rate : 0.000250 loss : 0.349510 +[02:36:10.161] iteration 37000 [68.24 sec]: learning rate : 0.000250 loss : 0.312354 +[02:37:00.135] iteration 37100 [118.21 sec]: learning rate : 0.000250 loss : 0.299600 +[02:37:50.617] iteration 37200 [168.69 sec]: learning rate : 0.000250 loss : 0.253810 +[02:38:40.581] iteration 37300 [218.68 sec]: learning rate : 0.000250 loss : 0.340352 +[02:39:30.969] iteration 37400 [269.04 sec]: learning rate : 0.000250 loss : 0.361020 +[02:39:50.958] Epoch 64 Evaluation: +[02:40:36.090] average MSE: 0.10986846370188619 average PSNR: 22.611816320499976 average SSIM: 0.6241417878778164 +[02:41:06.460] iteration 37500 [30.34 sec]: learning rate : 0.000250 loss : 0.351743 +[02:41:56.408] iteration 37600 [80.29 sec]: learning rate : 0.000250 loss : 0.335304 +[02:42:46.457] iteration 37700 [130.34 sec]: learning rate : 0.000250 loss : 0.407568 +[02:43:36.915] iteration 37800 [180.80 sec]: learning rate : 0.000250 loss : 0.328187 +[02:44:26.924] iteration 37900 [230.81 sec]: learning rate : 0.000250 loss : 0.327130 +[02:45:17.040] iteration 38000 [280.92 sec]: learning rate : 0.000250 loss : 0.306801 +[02:45:25.049] Epoch 65 Evaluation: +[02:46:11.569] average MSE: 0.09400503802140664 average PSNR: 23.281657901350727 average SSIM: 0.6230004886760676 +[02:46:53.727] iteration 38100 [42.13 sec]: learning rate : 0.000250 loss : 0.368209 +[02:47:44.185] iteration 38200 [92.59 sec]: learning rate : 0.000250 loss : 0.384090 +[02:48:34.538] iteration 38300 [142.94 sec]: learning rate : 0.000250 loss : 0.365540 +[02:49:24.737] iteration 38400 [193.16 sec]: learning rate : 0.000250 loss : 0.346750 +[02:50:14.773] iteration 38500 [243.18 sec]: learning rate : 0.000250 loss : 0.325020 +[02:51:00.867] Epoch 66 Evaluation: +[02:51:45.869] average MSE: 0.1053361192976287 average PSNR: 22.80490681299151 average SSIM: 0.6111152526007553 +[02:51:50.084] iteration 38600 [4.19 sec]: learning rate : 0.000250 loss : 0.360250 +[02:52:40.006] iteration 38700 [54.11 sec]: learning rate : 0.000250 loss : 0.392500 +[02:53:30.075] iteration 38800 [104.18 sec]: learning rate : 0.000250 loss : 0.573030 +[02:54:20.086] iteration 38900 [154.19 sec]: learning rate : 0.000250 loss : 0.308260 +[02:55:10.011] iteration 39000 [204.12 sec]: learning rate : 0.000250 loss : 0.443056 +[02:56:00.909] iteration 39100 [255.01 sec]: learning rate : 0.000250 loss : 0.289210 +[02:56:35.326] Epoch 67 Evaluation: +[02:57:19.887] average MSE: 0.10613456056517098 average PSNR: 
22.755258980482736 average SSIM: 0.6243211688532995 +[02:57:36.162] iteration 39200 [16.25 sec]: learning rate : 0.000250 loss : 0.293775 +[02:58:26.311] iteration 39300 [66.40 sec]: learning rate : 0.000250 loss : 0.363964 +[02:59:16.413] iteration 39400 [116.50 sec]: learning rate : 0.000250 loss : 0.331483 +[03:00:06.430] iteration 39500 [166.52 sec]: learning rate : 0.000250 loss : 0.290057 +[03:00:56.543] iteration 39600 [216.63 sec]: learning rate : 0.000250 loss : 0.319517 +[03:01:47.027] iteration 39700 [267.11 sec]: learning rate : 0.000250 loss : 0.267029 +[03:02:09.045] Epoch 68 Evaluation: +[03:02:55.533] average MSE: 0.05951057427706202 average PSNR: 25.318791913582626 average SSIM: 0.6668262315699661 +[03:03:24.185] iteration 39800 [28.63 sec]: learning rate : 0.000250 loss : 0.252743 +[03:04:14.333] iteration 39900 [78.77 sec]: learning rate : 0.000250 loss : 0.302436 +[03:05:04.476] iteration 40000 [128.92 sec]: learning rate : 0.000063 loss : 0.310683 +[03:05:04.637] save model to model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/iter_40000.pth +[03:05:55.074] iteration 40100 [179.51 sec]: learning rate : 0.000125 loss : 0.377971 +[03:06:45.107] iteration 40200 [229.55 sec]: learning rate : 0.000125 loss : 0.367970 +[03:07:35.570] iteration 40300 [280.04 sec]: learning rate : 0.000125 loss : 0.347702 +[03:07:45.589] Epoch 69 Evaluation: +[03:08:32.468] average MSE: 0.12982990106939227 average PSNR: 21.898717931104503 average SSIM: 0.6227870974757455 +[03:09:12.634] iteration 40400 [40.14 sec]: learning rate : 0.000125 loss : 0.366432 +[03:10:02.791] iteration 40500 [90.30 sec]: learning rate : 0.000125 loss : 0.362712 +[03:10:53.264] iteration 40600 [140.77 sec]: learning rate : 0.000125 loss : 0.342472 +[03:11:43.367] iteration 40700 [190.87 sec]: learning rate : 0.000125 loss : 0.398605 +[03:12:33.437] iteration 40800 [240.94 sec]: learning rate : 0.000125 loss : 0.289235 +[03:13:21.377] Epoch 70 Evaluation: +[03:14:08.294] average MSE: 0.14236869282310835 average PSNR: 21.512281774817303 average SSIM: 0.6121495025468262 +[03:14:10.512] iteration 40900 [2.19 sec]: learning rate : 0.000125 loss : 0.299362 +[03:15:00.821] iteration 41000 [52.50 sec]: learning rate : 0.000125 loss : 0.281291 +[03:15:50.830] iteration 41100 [102.51 sec]: learning rate : 0.000125 loss : 0.406997 +[03:16:40.747] iteration 41200 [152.43 sec]: learning rate : 0.000125 loss : 0.252443 +[03:17:30.787] iteration 41300 [202.47 sec]: learning rate : 0.000125 loss : 0.337661 +[03:18:21.219] iteration 41400 [252.90 sec]: learning rate : 0.000125 loss : 0.345482 +[03:18:57.173] Epoch 71 Evaluation: +[03:19:43.547] average MSE: 0.10278778082958409 average PSNR: 22.89704863484758 average SSIM: 0.6267003173563044 +[03:19:57.747] iteration 41500 [14.18 sec]: learning rate : 0.000125 loss : 0.286969 +[03:20:47.813] iteration 41600 [64.24 sec]: learning rate : 0.000125 loss : 0.313968 +[03:21:37.737] iteration 41700 [114.16 sec]: learning rate : 0.000125 loss : 0.291949 +[03:22:27.756] iteration 41800 [164.18 sec]: learning rate : 0.000125 loss : 0.360807 +[03:23:18.040] iteration 41900 [214.47 sec]: learning rate : 0.000125 loss : 0.406006 +[03:24:07.966] iteration 42000 [264.39 sec]: learning rate : 0.000125 loss : 0.328086 +[03:24:32.069] Epoch 72 Evaluation: +[03:25:17.825] average MSE: 0.05933495365981386 average PSNR: 25.303488471649768 average SSIM: 0.6733101426327704 +[03:25:44.809] iteration 42100 [26.96 sec]: learning rate : 0.000125 loss : 0.303176 +[03:26:34.900] iteration 42200 [77.05 sec]: 
learning rate : 0.000125 loss : 0.383884 +[03:27:24.851] iteration 42300 [127.00 sec]: learning rate : 0.000125 loss : 0.324675 +[03:28:14.883] iteration 42400 [177.03 sec]: learning rate : 0.000125 loss : 0.266711 +[03:29:04.919] iteration 42500 [227.07 sec]: learning rate : 0.000125 loss : 0.287845 +[03:29:54.859] iteration 42600 [277.01 sec]: learning rate : 0.000125 loss : 0.265497 +[03:30:06.851] Epoch 73 Evaluation: +[03:30:52.392] average MSE: 0.11951223996802536 average PSNR: 22.256856800543538 average SSIM: 0.632831206560476 +[03:31:30.758] iteration 42700 [38.34 sec]: learning rate : 0.000125 loss : 0.312884 +[03:32:21.512] iteration 42800 [89.09 sec]: learning rate : 0.000125 loss : 0.396234 +[03:33:11.814] iteration 42900 [139.42 sec]: learning rate : 0.000125 loss : 0.332813 +[03:34:01.974] iteration 43000 [189.56 sec]: learning rate : 0.000125 loss : 0.279696 +[03:34:51.967] iteration 43100 [239.55 sec]: learning rate : 0.000125 loss : 0.296688 +[03:35:41.996] iteration 43200 [289.58 sec]: learning rate : 0.000125 loss : 0.319528 +[03:35:42.035] Epoch 74 Evaluation: +[03:36:28.929] average MSE: 0.06186289186688732 average PSNR: 25.12198920251703 average SSIM: 0.655283394098588 +[03:37:19.230] iteration 43300 [50.27 sec]: learning rate : 0.000125 loss : 0.347496 +[03:38:09.531] iteration 43400 [100.58 sec]: learning rate : 0.000125 loss : 0.331120 +[03:38:59.597] iteration 43500 [150.64 sec]: learning rate : 0.000125 loss : 0.298066 +[03:39:49.713] iteration 43600 [200.76 sec]: learning rate : 0.000125 loss : 0.353790 +[03:40:40.141] iteration 43700 [251.19 sec]: learning rate : 0.000125 loss : 0.302220 +[03:41:18.683] Epoch 75 Evaluation: +[03:42:03.941] average MSE: 0.12665215674388924 average PSNR: 21.99784023582335 average SSIM: 0.6111651293029824 +[03:42:16.152] iteration 43800 [12.19 sec]: learning rate : 0.000125 loss : 0.316714 +[03:43:06.227] iteration 43900 [62.26 sec]: learning rate : 0.000125 loss : 0.347994 +[03:43:56.588] iteration 44000 [112.62 sec]: learning rate : 0.000125 loss : 0.319603 +[03:44:46.834] iteration 44100 [162.87 sec]: learning rate : 0.000125 loss : 0.325318 +[03:45:36.859] iteration 44200 [212.89 sec]: learning rate : 0.000125 loss : 0.293233 +[03:46:26.987] iteration 44300 [263.02 sec]: learning rate : 0.000125 loss : 0.373459 +[03:46:53.008] Epoch 76 Evaluation: +[03:47:40.468] average MSE: 0.10959389951826214 average PSNR: 22.61813077878462 average SSIM: 0.6388380426816104 +[03:48:04.810] iteration 44400 [24.32 sec]: learning rate : 0.000125 loss : 0.249050 +[03:48:55.207] iteration 44500 [74.71 sec]: learning rate : 0.000125 loss : 0.290447 +[03:49:45.291] iteration 44600 [124.80 sec]: learning rate : 0.000125 loss : 0.295088 +[03:50:36.182] iteration 44700 [175.69 sec]: learning rate : 0.000125 loss : 0.330858 +[03:51:26.129] iteration 44800 [225.64 sec]: learning rate : 0.000125 loss : 0.335296 +[03:52:16.215] iteration 44900 [275.72 sec]: learning rate : 0.000125 loss : 0.359057 +[03:52:30.226] Epoch 77 Evaluation: +[03:53:16.433] average MSE: 0.10130554957512583 average PSNR: 22.95504371385291 average SSIM: 0.6264161047851154 +[03:53:52.747] iteration 45000 [36.29 sec]: learning rate : 0.000125 loss : 0.238162 +[03:54:42.666] iteration 45100 [86.21 sec]: learning rate : 0.000125 loss : 0.337834 +[03:55:32.729] iteration 45200 [136.27 sec]: learning rate : 0.000125 loss : 0.349001 +[03:56:23.696] iteration 45300 [187.24 sec]: learning rate : 0.000125 loss : 0.288518 +[03:57:13.710] iteration 45400 [237.25 sec]: learning rate : 0.000125 
loss : 0.312014 +[03:58:03.811] iteration 45500 [287.35 sec]: learning rate : 0.000125 loss : 0.396794 +[03:58:05.821] Epoch 78 Evaluation: +[03:58:51.017] average MSE: 0.11328278001499233 average PSNR: 22.478232052830002 average SSIM: 0.6334696197365632 +[03:59:39.521] iteration 45600 [48.48 sec]: learning rate : 0.000125 loss : 0.334613 +[04:00:29.610] iteration 45700 [98.58 sec]: learning rate : 0.000125 loss : 0.303893 +[04:01:19.677] iteration 45800 [148.63 sec]: learning rate : 0.000125 loss : 0.268922 +[04:02:10.057] iteration 45900 [199.02 sec]: learning rate : 0.000125 loss : 0.314406 +[04:03:00.198] iteration 46000 [249.16 sec]: learning rate : 0.000125 loss : 0.329119 +[04:03:40.548] Epoch 79 Evaluation: +[04:04:27.194] average MSE: 0.12516101297997495 average PSNR: 22.0645955451474 average SSIM: 0.6353611678879627 +[04:04:37.398] iteration 46100 [10.18 sec]: learning rate : 0.000125 loss : 0.303601 +[04:05:27.310] iteration 46200 [60.09 sec]: learning rate : 0.000125 loss : 0.301698 +[04:06:17.387] iteration 46300 [110.17 sec]: learning rate : 0.000125 loss : 0.298398 +[04:07:07.442] iteration 46400 [160.22 sec]: learning rate : 0.000125 loss : 0.299701 +[04:07:57.764] iteration 46500 [210.55 sec]: learning rate : 0.000125 loss : 0.313583 +[04:08:48.150] iteration 46600 [260.93 sec]: learning rate : 0.000125 loss : 0.352271 +[04:09:16.120] Epoch 80 Evaluation: +[04:10:00.607] average MSE: 0.12010742816806136 average PSNR: 22.23124196679211 average SSIM: 0.6285272981720097 +[04:10:22.786] iteration 46700 [22.15 sec]: learning rate : 0.000125 loss : 0.392188 +[04:11:13.202] iteration 46800 [72.57 sec]: learning rate : 0.000125 loss : 0.356275 +[04:12:03.246] iteration 46900 [122.61 sec]: learning rate : 0.000125 loss : 0.292838 +[04:12:53.174] iteration 47000 [172.54 sec]: learning rate : 0.000125 loss : 0.288817 +[04:13:43.177] iteration 47100 [222.54 sec]: learning rate : 0.000125 loss : 0.368662 +[04:14:33.503] iteration 47200 [272.87 sec]: learning rate : 0.000125 loss : 0.314899 +[04:14:49.480] Epoch 81 Evaluation: +[04:15:34.920] average MSE: 0.06494067301171595 average PSNR: 24.889978850898714 average SSIM: 0.66223596415262 +[04:16:09.081] iteration 47300 [34.13 sec]: learning rate : 0.000125 loss : 0.364549 +[04:16:59.733] iteration 47400 [84.79 sec]: learning rate : 0.000125 loss : 0.355182 +[04:17:49.780] iteration 47500 [134.83 sec]: learning rate : 0.000125 loss : 0.265486 +[04:18:40.024] iteration 47600 [185.08 sec]: learning rate : 0.000125 loss : 0.319410 +[04:19:30.106] iteration 47700 [235.16 sec]: learning rate : 0.000125 loss : 0.385805 +[04:20:20.314] iteration 47800 [285.37 sec]: learning rate : 0.000125 loss : 0.290953 +[04:20:24.328] Epoch 82 Evaluation: +[04:21:10.371] average MSE: 0.11520134667354412 average PSNR: 22.39950936341879 average SSIM: 0.62619263742356 +[04:21:56.740] iteration 47900 [46.34 sec]: learning rate : 0.000125 loss : 0.366710 +[04:22:46.828] iteration 48000 [96.43 sec]: learning rate : 0.000125 loss : 0.405922 +[04:23:36.834] iteration 48100 [146.44 sec]: learning rate : 0.000125 loss : 0.380251 +[04:24:26.967] iteration 48200 [196.57 sec]: learning rate : 0.000125 loss : 0.302781 +[04:25:17.019] iteration 48300 [246.62 sec]: learning rate : 0.000125 loss : 0.389444 +[04:25:59.930] Epoch 83 Evaluation: +[04:26:45.538] average MSE: 0.10673518029689133 average PSNR: 22.733946631296647 average SSIM: 0.6298458613877866 +[04:26:53.770] iteration 48400 [8.21 sec]: learning rate : 0.000125 loss : 0.353913 +[04:27:43.867] iteration 48500 
[58.30 sec]: learning rate : 0.000125 loss : 0.350068 +[04:28:33.800] iteration 48600 [108.24 sec]: learning rate : 0.000125 loss : 0.276563 +[04:29:23.809] iteration 48700 [158.25 sec]: learning rate : 0.000125 loss : 0.269257 +[04:30:13.848] iteration 48800 [208.29 sec]: learning rate : 0.000125 loss : 0.335686 +[04:31:03.775] iteration 48900 [258.21 sec]: learning rate : 0.000125 loss : 0.287062 +[04:31:33.846] Epoch 84 Evaluation: +[04:32:20.288] average MSE: 0.08855468019444042 average PSNR: 23.541184225589976 average SSIM: 0.6276288566751397 +[04:32:40.576] iteration 49000 [20.26 sec]: learning rate : 0.000125 loss : 0.331212 +[04:33:31.027] iteration 49100 [70.71 sec]: learning rate : 0.000125 loss : 0.340855 +[04:34:20.988] iteration 49200 [120.67 sec]: learning rate : 0.000125 loss : 0.250910 +[04:35:11.407] iteration 49300 [171.09 sec]: learning rate : 0.000125 loss : 0.489424 +[04:36:01.422] iteration 49400 [221.10 sec]: learning rate : 0.000125 loss : 0.263405 +[04:36:51.346] iteration 49500 [271.03 sec]: learning rate : 0.000125 loss : 0.319590 +[04:37:09.325] Epoch 85 Evaluation: +[04:37:53.745] average MSE: 0.0472288483744437 average PSNR: 26.288486582413167 average SSIM: 0.69662366858267 +[04:38:26.575] iteration 49600 [32.80 sec]: learning rate : 0.000125 loss : 0.275226 +[04:39:16.642] iteration 49700 [82.90 sec]: learning rate : 0.000125 loss : 0.241152 +[04:40:06.627] iteration 49800 [132.86 sec]: learning rate : 0.000125 loss : 0.338060 +[04:40:56.734] iteration 49900 [182.96 sec]: learning rate : 0.000125 loss : 0.350133 +[04:41:47.072] iteration 50000 [233.30 sec]: learning rate : 0.000125 loss : 0.302382 +[04:42:37.154] iteration 50100 [283.38 sec]: learning rate : 0.000125 loss : 0.316246 +[04:42:43.162] Epoch 86 Evaluation: +[04:43:29.013] average MSE: 0.05279999699621796 average PSNR: 25.796504284355514 average SSIM: 0.6845501771979102 +[04:44:13.538] iteration 50200 [44.50 sec]: learning rate : 0.000125 loss : 0.385209 +[04:45:04.121] iteration 50300 [95.08 sec]: learning rate : 0.000125 loss : 0.344237 +[04:45:54.275] iteration 50400 [145.23 sec]: learning rate : 0.000125 loss : 0.367499 +[04:46:44.414] iteration 50500 [195.37 sec]: learning rate : 0.000125 loss : 0.299522 +[04:47:34.436] iteration 50600 [245.40 sec]: learning rate : 0.000125 loss : 0.297428 +[04:48:18.575] Epoch 87 Evaluation: +[04:49:05.937] average MSE: 0.054675899472810506 average PSNR: 25.642056186948974 average SSIM: 0.678195320618928 +[04:49:12.137] iteration 50700 [6.17 sec]: learning rate : 0.000125 loss : 0.357126 +[04:50:02.220] iteration 50800 [56.28 sec]: learning rate : 0.000125 loss : 0.350063 +[04:50:52.637] iteration 50900 [106.67 sec]: learning rate : 0.000125 loss : 0.299449 +[04:51:42.654] iteration 51000 [156.69 sec]: learning rate : 0.000125 loss : 0.396862 +[04:52:32.998] iteration 51100 [207.03 sec]: learning rate : 0.000125 loss : 0.336886 +[04:53:23.018] iteration 51200 [257.06 sec]: learning rate : 0.000125 loss : 0.330393 +[04:53:54.986] Epoch 88 Evaluation: +[04:54:39.891] average MSE: 0.09315040298717747 average PSNR: 23.31709766685768 average SSIM: 0.6344490520365292 +[04:54:58.246] iteration 51300 [18.33 sec]: learning rate : 0.000125 loss : 0.356632 +[04:55:48.150] iteration 51400 [68.23 sec]: learning rate : 0.000125 loss : 0.278523 +[04:56:39.038] iteration 51500 [119.12 sec]: learning rate : 0.000125 loss : 0.356871 +[04:57:29.076] iteration 51600 [169.16 sec]: learning rate : 0.000125 loss : 0.237516 +[04:58:19.013] iteration 51700 [219.10 sec]: learning 
rate : 0.000125 loss : 0.336172 +[04:59:09.061] iteration 51800 [269.14 sec]: learning rate : 0.000125 loss : 0.324677 +[04:59:29.041] Epoch 89 Evaluation: +[05:00:14.093] average MSE: 0.07752257032284346 average PSNR: 24.11672520134654 average SSIM: 0.6419967814460245 +[05:00:44.398] iteration 51900 [30.28 sec]: learning rate : 0.000125 loss : 0.410794 +[05:01:34.702] iteration 52000 [80.58 sec]: learning rate : 0.000125 loss : 0.343432 +[05:02:24.712] iteration 52100 [130.59 sec]: learning rate : 0.000125 loss : 0.369311 +[05:03:15.137] iteration 52200 [181.02 sec]: learning rate : 0.000125 loss : 0.326371 +[05:04:05.526] iteration 52300 [231.41 sec]: learning rate : 0.000125 loss : 0.314985 +[05:04:55.658] iteration 52400 [281.54 sec]: learning rate : 0.000125 loss : 0.313197 +[05:05:03.668] Epoch 90 Evaluation: +[05:05:49.556] average MSE: 0.07038813128601391 average PSNR: 24.539836447391384 average SSIM: 0.6404805066985532 +[05:06:31.870] iteration 52500 [42.29 sec]: learning rate : 0.000125 loss : 0.349469 +[05:07:22.041] iteration 52600 [92.46 sec]: learning rate : 0.000125 loss : 0.361883 +[05:08:12.033] iteration 52700 [142.45 sec]: learning rate : 0.000125 loss : 0.430876 +[05:09:02.428] iteration 52800 [192.85 sec]: learning rate : 0.000125 loss : 0.362863 +[05:09:52.442] iteration 52900 [242.86 sec]: learning rate : 0.000125 loss : 0.284712 +[05:10:38.936] Epoch 91 Evaluation: +[05:11:24.336] average MSE: 0.0875434635478839 average PSNR: 23.586016372102208 average SSIM: 0.634648873565529 +[05:11:28.609] iteration 53000 [4.24 sec]: learning rate : 0.000125 loss : 0.338417 +[05:12:18.545] iteration 53100 [54.18 sec]: learning rate : 0.000125 loss : 0.314214 +[05:13:08.643] iteration 53200 [104.27 sec]: learning rate : 0.000125 loss : 0.342530 +[05:13:58.677] iteration 53300 [154.33 sec]: learning rate : 0.000125 loss : 0.381554 +[05:14:49.126] iteration 53400 [204.76 sec]: learning rate : 0.000125 loss : 0.411436 +[05:15:39.156] iteration 53500 [254.79 sec]: learning rate : 0.000125 loss : 0.322660 +[05:16:13.122] Epoch 92 Evaluation: +[05:16:57.567] average MSE: 0.08756428586949418 average PSNR: 23.585773977108065 average SSIM: 0.6355802089478305 +[05:17:13.760] iteration 53600 [16.17 sec]: learning rate : 0.000125 loss : 0.281503 +[05:18:03.813] iteration 53700 [66.22 sec]: learning rate : 0.000125 loss : 0.351021 +[05:18:54.549] iteration 53800 [116.96 sec]: learning rate : 0.000125 loss : 0.404771 +[05:19:44.513] iteration 53900 [166.92 sec]: learning rate : 0.000125 loss : 0.292997 +[05:20:34.561] iteration 54000 [216.97 sec]: learning rate : 0.000125 loss : 0.299475 +[05:21:24.896] iteration 54100 [267.31 sec]: learning rate : 0.000125 loss : 0.285747 +[05:21:46.931] Epoch 93 Evaluation: +[05:22:31.877] average MSE: 0.09881493737667171 average PSNR: 23.062467161186426 average SSIM: 0.6283957534847133 +[05:23:00.045] iteration 54200 [28.14 sec]: learning rate : 0.000125 loss : 0.311693 +[05:23:50.112] iteration 54300 [78.21 sec]: learning rate : 0.000125 loss : 0.246206 +[05:24:40.114] iteration 54400 [128.21 sec]: learning rate : 0.000125 loss : 0.297522 +[05:25:30.028] iteration 54500 [178.12 sec]: learning rate : 0.000125 loss : 0.322012 +[05:26:20.539] iteration 54600 [228.64 sec]: learning rate : 0.000125 loss : 0.365313 +[05:27:10.798] iteration 54700 [278.89 sec]: learning rate : 0.000125 loss : 0.353549 +[05:27:20.895] Epoch 94 Evaluation: +[05:28:07.508] average MSE: 0.09356295743355293 average PSNR: 23.29649497794085 average SSIM: 0.6378830733436317 +[05:28:47.691] 
iteration 54800 [40.15 sec]: learning rate : 0.000125 loss : 0.374171 +[05:29:37.778] iteration 54900 [90.24 sec]: learning rate : 0.000125 loss : 0.342357 +[05:30:27.723] iteration 55000 [140.19 sec]: learning rate : 0.000125 loss : 0.331881 +[05:31:17.791] iteration 55100 [190.25 sec]: learning rate : 0.000125 loss : 0.346606 +[05:32:07.846] iteration 55200 [240.31 sec]: learning rate : 0.000125 loss : 0.240560 +[05:32:56.218] Epoch 95 Evaluation: +[05:33:43.027] average MSE: 0.10189378852730395 average PSNR: 22.931966721179325 average SSIM: 0.6414167997999348 +[05:33:45.244] iteration 55300 [2.19 sec]: learning rate : 0.000125 loss : 0.259756 +[05:34:35.392] iteration 55400 [52.34 sec]: learning rate : 0.000125 loss : 0.285954 +[05:35:25.401] iteration 55500 [102.35 sec]: learning rate : 0.000125 loss : 0.418047 +[05:36:15.329] iteration 55600 [152.28 sec]: learning rate : 0.000125 loss : 0.246013 +[05:37:05.737] iteration 55700 [202.68 sec]: learning rate : 0.000125 loss : 0.322535 +[05:37:55.719] iteration 55800 [252.67 sec]: learning rate : 0.000125 loss : 0.306653 +[05:38:31.810] Epoch 96 Evaluation: +[05:39:17.897] average MSE: 0.05585292787010946 average PSNR: 25.55271362072208 average SSIM: 0.6701733865266049 +[05:39:32.198] iteration 55900 [14.28 sec]: learning rate : 0.000125 loss : 0.269365 +[05:40:22.338] iteration 56000 [64.42 sec]: learning rate : 0.000125 loss : 0.316992 +[05:41:12.765] iteration 56100 [114.84 sec]: learning rate : 0.000125 loss : 0.318961 +[05:42:02.914] iteration 56200 [164.99 sec]: learning rate : 0.000125 loss : 0.367398 +[05:42:53.101] iteration 56300 [215.18 sec]: learning rate : 0.000125 loss : 0.340480 +[05:43:43.078] iteration 56400 [265.15 sec]: learning rate : 0.000125 loss : 0.346829 +[05:44:07.190] Epoch 97 Evaluation: +[05:44:52.158] average MSE: 0.057037546581925724 average PSNR: 25.462339275349553 average SSIM: 0.6678854812613878 +[05:45:18.934] iteration 56500 [26.75 sec]: learning rate : 0.000125 loss : 0.353074 +[05:46:09.531] iteration 56600 [77.35 sec]: learning rate : 0.000125 loss : 0.307931 +[05:46:59.554] iteration 56700 [127.37 sec]: learning rate : 0.000125 loss : 0.274513 +[05:47:49.683] iteration 56800 [177.50 sec]: learning rate : 0.000125 loss : 0.303837 +[05:48:40.008] iteration 56900 [227.83 sec]: learning rate : 0.000125 loss : 0.338713 +[05:49:30.137] iteration 57000 [277.95 sec]: learning rate : 0.000125 loss : 0.295113 +[05:49:42.158] Epoch 98 Evaluation: +[05:50:28.295] average MSE: 0.08609690340301453 average PSNR: 23.65548693490377 average SSIM: 0.6408786740958768 +[05:51:06.606] iteration 57100 [38.29 sec]: learning rate : 0.000125 loss : 0.306994 +[05:51:56.863] iteration 57200 [88.54 sec]: learning rate : 0.000125 loss : 0.412685 +[05:52:46.921] iteration 57300 [138.60 sec]: learning rate : 0.000125 loss : 0.330332 +[05:53:36.956] iteration 57400 [188.63 sec]: learning rate : 0.000125 loss : 0.286039 +[05:54:26.901] iteration 57500 [238.58 sec]: learning rate : 0.000125 loss : 0.273572 +[05:55:17.327] iteration 57600 [289.01 sec]: learning rate : 0.000125 loss : 0.392283 +[05:55:17.372] Epoch 99 Evaluation: +[05:56:02.574] average MSE: 0.0733075164124867 average PSNR: 24.35840363042069 average SSIM: 0.6453024205309947 +[05:56:52.951] iteration 57700 [50.35 sec]: learning rate : 0.000125 loss : 0.284022 +[05:57:43.389] iteration 57800 [100.79 sec]: learning rate : 0.000125 loss : 0.369725 +[05:58:33.521] iteration 57900 [150.92 sec]: learning rate : 0.000125 loss : 0.262817 +[05:59:23.668] iteration 58000 [201.07 
sec]: learning rate : 0.000125 loss : 0.352506 +[06:00:13.697] iteration 58100 [251.10 sec]: learning rate : 0.000125 loss : 0.374614 +[06:00:51.809] Epoch 100 Evaluation: +[06:01:38.328] average MSE: 0.09036821144463328 average PSNR: 23.449829817728727 average SSIM: 0.6393325137015033 +[06:01:50.531] iteration 58200 [12.18 sec]: learning rate : 0.000125 loss : 0.305311 +[06:02:40.439] iteration 58300 [62.08 sec]: learning rate : 0.000125 loss : 0.354027 +[06:03:32.223] iteration 58400 [113.87 sec]: learning rate : 0.000125 loss : 0.281681 +[06:04:22.328] iteration 58500 [163.97 sec]: learning rate : 0.000125 loss : 0.317355 +[06:05:12.331] iteration 58600 [213.98 sec]: learning rate : 0.000125 loss : 0.265304 +[06:06:02.449] iteration 58700 [264.10 sec]: learning rate : 0.000125 loss : 0.369561 +[06:06:28.470] Epoch 101 Evaluation: +[06:07:13.166] average MSE: 0.07986426915215603 average PSNR: 23.984260576682694 average SSIM: 0.6430092603039274 +[06:07:37.559] iteration 58800 [24.37 sec]: learning rate : 0.000125 loss : 0.275870 +[06:08:27.545] iteration 58900 [74.35 sec]: learning rate : 0.000125 loss : 0.310047 +[06:09:18.028] iteration 59000 [124.86 sec]: learning rate : 0.000125 loss : 0.275605 +[06:10:08.449] iteration 59100 [175.26 sec]: learning rate : 0.000125 loss : 0.350469 +[06:10:58.916] iteration 59200 [225.72 sec]: learning rate : 0.000125 loss : 0.373759 +[06:11:48.964] iteration 59300 [275.77 sec]: learning rate : 0.000125 loss : 0.357522 +[06:12:02.977] Epoch 102 Evaluation: +[06:12:48.711] average MSE: 0.07620527970698478 average PSNR: 24.192022512636793 average SSIM: 0.6385191843430784 +[06:13:24.917] iteration 59400 [36.18 sec]: learning rate : 0.000125 loss : 0.226876 +[06:14:14.991] iteration 59500 [86.25 sec]: learning rate : 0.000125 loss : 0.414190 +[06:15:05.114] iteration 59600 [136.38 sec]: learning rate : 0.000125 loss : 0.379765 +[06:15:55.450] iteration 59700 [186.71 sec]: learning rate : 0.000125 loss : 0.300407 +[06:16:45.585] iteration 59800 [236.85 sec]: learning rate : 0.000125 loss : 0.321137 +[06:17:35.645] iteration 59900 [286.91 sec]: learning rate : 0.000125 loss : 0.393317 +[06:17:37.655] Epoch 103 Evaluation: +[06:18:25.296] average MSE: 0.08049427843465054 average PSNR: 23.951797393874283 average SSIM: 0.639138472680051 +[06:19:13.444] iteration 60000 [48.12 sec]: learning rate : 0.000031 loss : 0.310206 +[06:19:13.603] save model to model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/iter_60000.pth +[06:20:03.647] iteration 60100 [98.32 sec]: learning rate : 0.000063 loss : 0.347546 +[06:20:53.664] iteration 60200 [148.34 sec]: learning rate : 0.000063 loss : 0.249237 +[06:21:44.432] iteration 60300 [199.11 sec]: learning rate : 0.000063 loss : 0.318537 +[06:22:34.462] iteration 60400 [249.14 sec]: learning rate : 0.000063 loss : 0.292461 +[06:23:14.408] Epoch 104 Evaluation: +[06:23:59.608] average MSE: 0.10360416095421522 average PSNR: 22.859101539286872 average SSIM: 0.6383643648738418 +[06:24:09.995] iteration 60500 [10.36 sec]: learning rate : 0.000063 loss : 0.369032 +[06:24:59.917] iteration 60600 [60.28 sec]: learning rate : 0.000063 loss : 0.349249 +[06:25:50.321] iteration 60700 [110.69 sec]: learning rate : 0.000063 loss : 0.333260 +[06:26:40.240] iteration 60800 [160.60 sec]: learning rate : 0.000063 loss : 0.282598 +[06:27:30.573] iteration 60900 [210.94 sec]: learning rate : 0.000063 loss : 0.291143 +[06:28:20.602] iteration 61000 [260.97 sec]: learning rate : 0.000063 loss : 0.352893 +[06:28:48.586] Epoch 105 Evaluation: 
+[06:29:33.766] average MSE: 0.09323514138945001 average PSNR: 23.310365632417433 average SSIM: 0.6373132313901523 +[06:29:55.943] iteration 61100 [22.15 sec]: learning rate : 0.000063 loss : 0.372553 +[06:30:46.378] iteration 61200 [72.59 sec]: learning rate : 0.000063 loss : 0.364996 +[06:31:36.413] iteration 61300 [122.62 sec]: learning rate : 0.000063 loss : 0.294861 +[06:32:26.341] iteration 61400 [172.55 sec]: learning rate : 0.000063 loss : 0.287071 +[06:33:17.007] iteration 61500 [223.22 sec]: learning rate : 0.000063 loss : 0.374940 +[06:34:07.408] iteration 61600 [273.62 sec]: learning rate : 0.000063 loss : 0.310548 +[06:34:23.399] Epoch 106 Evaluation: +[06:35:08.343] average MSE: 0.0917759709958827 average PSNR: 23.378923773266152 average SSIM: 0.6361971911590917 +[06:35:42.513] iteration 61700 [34.14 sec]: learning rate : 0.000063 loss : 0.383219 +[06:36:32.601] iteration 61800 [84.23 sec]: learning rate : 0.000063 loss : 0.333880 +[06:37:22.623] iteration 61900 [134.25 sec]: learning rate : 0.000063 loss : 0.290267 +[06:38:12.568] iteration 62000 [184.20 sec]: learning rate : 0.000063 loss : 0.323121 +[06:39:03.081] iteration 62100 [234.71 sec]: learning rate : 0.000063 loss : 0.353817 +[06:39:53.627] iteration 62200 [285.26 sec]: learning rate : 0.000063 loss : 0.268113 +[06:39:57.757] Epoch 107 Evaluation: +[06:40:43.467] average MSE: 0.06356432034842427 average PSNR: 24.982327545008427 average SSIM: 0.6642107705227743 +[06:41:30.001] iteration 62300 [46.51 sec]: learning rate : 0.000063 loss : 0.295520 +[06:42:20.094] iteration 62400 [96.60 sec]: learning rate : 0.000063 loss : 0.385697 +[06:43:10.027] iteration 62500 [146.53 sec]: learning rate : 0.000063 loss : 0.388215 +[06:44:00.070] iteration 62600 [196.58 sec]: learning rate : 0.000063 loss : 0.269823 +[06:44:50.094] iteration 62700 [246.60 sec]: learning rate : 0.000063 loss : 0.379035 +[06:45:32.460] Epoch 108 Evaluation: +[06:46:17.930] average MSE: 0.09370399826658468 average PSNR: 23.292931641462804 average SSIM: 0.643785816707966 +[06:46:26.132] iteration 62800 [8.18 sec]: learning rate : 0.000063 loss : 0.355283 +[06:47:16.201] iteration 62900 [58.25 sec]: learning rate : 0.000063 loss : 0.391930 +[06:48:06.682] iteration 63000 [108.73 sec]: learning rate : 0.000063 loss : 0.325985 +[06:48:57.241] iteration 63100 [159.29 sec]: learning rate : 0.000063 loss : 0.296474 +[06:49:47.311] iteration 63200 [209.36 sec]: learning rate : 0.000063 loss : 0.317409 +[06:50:37.357] iteration 63300 [259.40 sec]: learning rate : 0.000063 loss : 0.271566 +[06:51:07.337] Epoch 109 Evaluation: +[06:51:53.498] average MSE: 0.077601424637257 average PSNR: 24.109087007285197 average SSIM: 0.6485313547144921 +[06:52:13.790] iteration 63400 [20.27 sec]: learning rate : 0.000063 loss : 0.279190 +[06:53:03.933] iteration 63500 [70.41 sec]: learning rate : 0.000063 loss : 0.367769 +[06:53:53.902] iteration 63600 [120.38 sec]: learning rate : 0.000063 loss : 0.246307 +[06:54:43.930] iteration 63700 [170.41 sec]: learning rate : 0.000063 loss : 0.395176 +[06:55:33.987] iteration 63800 [220.46 sec]: learning rate : 0.000063 loss : 0.256918 +[06:56:24.039] iteration 63900 [270.52 sec]: learning rate : 0.000063 loss : 0.330646 +[06:56:42.141] Epoch 110 Evaluation: +[06:57:29.799] average MSE: 0.11152041871732687 average PSNR: 22.546229191892433 average SSIM: 0.6404362460398262 +[06:58:02.342] iteration 64000 [32.52 sec]: learning rate : 0.000063 loss : 0.313377 +[06:58:52.475] iteration 64100 [82.65 sec]: learning rate : 0.000063 loss : 
0.257505 +[06:59:42.473] iteration 64200 [132.65 sec]: learning rate : 0.000063 loss : 0.335420 +[07:00:32.590] iteration 64300 [182.76 sec]: learning rate : 0.000063 loss : 0.360957 +[07:01:22.673] iteration 64400 [232.85 sec]: learning rate : 0.000063 loss : 0.360356 +[07:02:12.646] iteration 64500 [282.82 sec]: learning rate : 0.000063 loss : 0.309733 +[07:02:18.654] Epoch 111 Evaluation: +[07:03:06.520] average MSE: 0.07473804537684797 average PSNR: 24.274543083082435 average SSIM: 0.6469288690565689 +[07:03:51.625] iteration 64600 [45.08 sec]: learning rate : 0.000063 loss : 0.350162 +[07:04:41.662] iteration 64700 [95.12 sec]: learning rate : 0.000063 loss : 0.378084 +[07:05:31.581] iteration 64800 [145.04 sec]: learning rate : 0.000063 loss : 0.402906 +[07:06:21.986] iteration 64900 [195.44 sec]: learning rate : 0.000063 loss : 0.276191 +[07:07:12.057] iteration 65000 [245.51 sec]: learning rate : 0.000063 loss : 0.336250 +[07:07:56.037] Epoch 112 Evaluation: +[07:08:41.626] average MSE: 0.08954816347688195 average PSNR: 23.486448086897337 average SSIM: 0.6370379770247723 +[07:08:47.834] iteration 65100 [6.18 sec]: learning rate : 0.000063 loss : 0.326900 +[07:09:38.330] iteration 65200 [56.70 sec]: learning rate : 0.000063 loss : 0.247968 +[07:10:28.321] iteration 65300 [106.67 sec]: learning rate : 0.000063 loss : 0.287005 +[07:11:18.801] iteration 65400 [157.15 sec]: learning rate : 0.000063 loss : 0.400246 +[07:12:08.827] iteration 65500 [207.17 sec]: learning rate : 0.000063 loss : 0.357386 +[07:12:58.750] iteration 65600 [257.10 sec]: learning rate : 0.000063 loss : 0.302774 +[07:13:30.810] Epoch 113 Evaluation: +[07:14:15.905] average MSE: 0.09363167283420275 average PSNR: 23.292271888532138 average SSIM: 0.6387706972708069 +[07:14:34.092] iteration 65700 [18.16 sec]: learning rate : 0.000063 loss : 0.336246 +[07:15:24.638] iteration 65800 [68.71 sec]: learning rate : 0.000063 loss : 0.304618 +[07:16:15.047] iteration 65900 [119.12 sec]: learning rate : 0.000063 loss : 0.294047 +[07:17:05.196] iteration 66000 [169.27 sec]: learning rate : 0.000063 loss : 0.250647 +[07:17:55.338] iteration 66100 [219.41 sec]: learning rate : 0.000063 loss : 0.367206 +[07:18:45.843] iteration 66200 [269.91 sec]: learning rate : 0.000063 loss : 0.378439 +[07:19:05.833] Epoch 114 Evaluation: +[07:19:52.084] average MSE: 0.10021897488423386 average PSNR: 22.998373880143472 average SSIM: 0.6354192521783403 +[07:20:22.516] iteration 66300 [30.41 sec]: learning rate : 0.000063 loss : 0.368713 +[07:21:12.503] iteration 66400 [80.40 sec]: learning rate : 0.000063 loss : 0.347396 +[07:22:02.762] iteration 66500 [130.65 sec]: learning rate : 0.000063 loss : 0.371522 +[07:22:52.880] iteration 66600 [180.77 sec]: learning rate : 0.000063 loss : 0.368923 +[07:23:43.240] iteration 66700 [231.13 sec]: learning rate : 0.000063 loss : 0.342131 +[07:24:33.329] iteration 66800 [281.22 sec]: learning rate : 0.000063 loss : 0.350700 +[07:24:41.321] Epoch 115 Evaluation: +[07:25:26.664] average MSE: 0.07644946511929454 average PSNR: 24.17690187633899 average SSIM: 0.6497639238145558 +[07:26:09.500] iteration 66900 [42.81 sec]: learning rate : 0.000063 loss : 0.426400 +[07:26:59.474] iteration 67000 [92.78 sec]: learning rate : 0.000063 loss : 0.346897 +[07:27:49.955] iteration 67100 [143.26 sec]: learning rate : 0.000063 loss : 0.410494 +[07:28:39.993] iteration 67200 [193.30 sec]: learning rate : 0.000063 loss : 0.324962 +[07:29:30.048] iteration 67300 [243.36 sec]: learning rate : 0.000063 loss : 0.296738 
+[07:30:16.156] Epoch 116 Evaluation: +[07:31:02.666] average MSE: 0.07291973632883084 average PSNR: 24.38195023996 average SSIM: 0.6544937519129405 +[07:31:06.889] iteration 67400 [4.20 sec]: learning rate : 0.000063 loss : 0.341656 +[07:31:56.813] iteration 67500 [54.12 sec]: learning rate : 0.000063 loss : 0.354875 +[07:32:47.461] iteration 67600 [104.77 sec]: learning rate : 0.000063 loss : 0.332537 +[07:33:37.853] iteration 67700 [155.16 sec]: learning rate : 0.000063 loss : 0.339293 +[07:34:28.156] iteration 67800 [205.46 sec]: learning rate : 0.000063 loss : 0.465392 +[07:35:18.238] iteration 67900 [255.55 sec]: learning rate : 0.000063 loss : 0.281383 +[07:35:52.215] Epoch 117 Evaluation: +[07:36:38.991] average MSE: 0.09107986750280858 average PSNR: 23.413647201119293 average SSIM: 0.63603987213402 +[07:36:55.363] iteration 68000 [16.34 sec]: learning rate : 0.000063 loss : 0.328532 +[07:37:45.323] iteration 68100 [66.30 sec]: learning rate : 0.000063 loss : 0.339819 +[07:38:35.407] iteration 68200 [116.39 sec]: learning rate : 0.000063 loss : 0.269414 +[07:39:25.370] iteration 68300 [166.35 sec]: learning rate : 0.000063 loss : 0.329131 +[07:40:15.861] iteration 68400 [216.84 sec]: learning rate : 0.000063 loss : 0.332459 +[07:41:06.256] iteration 68500 [267.23 sec]: learning rate : 0.000063 loss : 0.323236 +[07:41:28.265] Epoch 118 Evaluation: +[07:42:14.040] average MSE: 0.09736859168814449 average PSNR: 23.12410885469358 average SSIM: 0.6358730281347239 +[07:42:42.224] iteration 68600 [28.16 sec]: learning rate : 0.000063 loss : 0.288030 +[07:43:32.347] iteration 68700 [78.28 sec]: learning rate : 0.000063 loss : 0.248398 +[07:44:22.399] iteration 68800 [128.33 sec]: learning rate : 0.000063 loss : 0.326419 +[07:45:12.352] iteration 68900 [178.29 sec]: learning rate : 0.000063 loss : 0.344148 +[07:46:02.725] iteration 69000 [228.66 sec]: learning rate : 0.000063 loss : 0.391684 +[07:46:52.771] iteration 69100 [278.71 sec]: learning rate : 0.000063 loss : 0.352266 +[07:47:02.763] Epoch 119 Evaluation: +[07:47:47.912] average MSE: 0.10189676582026433 average PSNR: 22.930128177139363 average SSIM: 0.6426741418246052 +[07:48:28.575] iteration 69200 [40.63 sec]: learning rate : 0.000063 loss : 0.389532 +[07:49:18.684] iteration 69300 [90.74 sec]: learning rate : 0.000063 loss : 0.392722 +[07:50:08.706] iteration 69400 [140.76 sec]: learning rate : 0.000063 loss : 0.289427 +[07:50:58.988] iteration 69500 [191.05 sec]: learning rate : 0.000063 loss : 0.371884 +[07:51:49.007] iteration 69600 [241.07 sec]: learning rate : 0.000063 loss : 0.266705 +[07:52:37.432] Epoch 120 Evaluation: +[07:53:24.072] average MSE: 0.07473385081154181 average PSNR: 24.276340939441365 average SSIM: 0.6487144916751062 +[07:53:26.290] iteration 69700 [2.19 sec]: learning rate : 0.000063 loss : 0.257887 +[07:54:16.368] iteration 69800 [52.27 sec]: learning rate : 0.000063 loss : 0.294886 +[07:55:06.386] iteration 69900 [102.29 sec]: learning rate : 0.000063 loss : 0.402076 +[07:55:56.870] iteration 70000 [152.77 sec]: learning rate : 0.000063 loss : 0.272078 +[07:56:46.912] iteration 70100 [202.81 sec]: learning rate : 0.000063 loss : 0.300497 +[07:57:37.025] iteration 70200 [252.93 sec]: learning rate : 0.000063 loss : 0.307721 +[07:58:13.292] Epoch 121 Evaluation: +[07:58:57.980] average MSE: 0.08116476457617254 average PSNR: 23.913134660603244 average SSIM: 0.6489293824707506 +[07:59:12.184] iteration 70300 [14.18 sec]: learning rate : 0.000063 loss : 0.295300 +[08:00:02.662] iteration 70400 [64.66 sec]: 
learning rate : 0.000063 loss : 0.351269 +[08:00:52.740] iteration 70500 [114.74 sec]: learning rate : 0.000063 loss : 0.315232 +[08:01:42.692] iteration 70600 [164.69 sec]: learning rate : 0.000063 loss : 0.399963 +[08:02:32.757] iteration 70700 [214.75 sec]: learning rate : 0.000063 loss : 0.332647 +[08:03:23.113] iteration 70800 [265.11 sec]: learning rate : 0.000063 loss : 0.352791 +[08:03:47.181] Epoch 122 Evaluation: +[08:04:34.487] average MSE: 0.08044428573544402 average PSNR: 23.952265487472477 average SSIM: 0.6463098248836898 +[08:05:00.726] iteration 70900 [26.21 sec]: learning rate : 0.000063 loss : 0.341998 +[08:05:50.861] iteration 71000 [76.35 sec]: learning rate : 0.000063 loss : 0.346473 +[08:06:40.828] iteration 71100 [126.31 sec]: learning rate : 0.000063 loss : 0.293418 +[08:07:30.869] iteration 71200 [176.36 sec]: learning rate : 0.000063 loss : 0.292766 +[08:08:21.301] iteration 71300 [226.79 sec]: learning rate : 0.000063 loss : 0.331227 +[08:09:11.259] iteration 71400 [276.75 sec]: learning rate : 0.000063 loss : 0.277675 +[08:09:23.253] Epoch 123 Evaluation: +[08:10:09.312] average MSE: 0.10242570325152545 average PSNR: 22.908222938554573 average SSIM: 0.6418538813521503 +[08:10:47.846] iteration 71500 [38.51 sec]: learning rate : 0.000063 loss : 0.293909 +[08:11:37.950] iteration 71600 [88.61 sec]: learning rate : 0.000063 loss : 0.344035 +[08:12:27.893] iteration 71700 [138.55 sec]: learning rate : 0.000063 loss : 0.355991 +[08:13:17.927] iteration 71800 [188.59 sec]: learning rate : 0.000063 loss : 0.258182 +[08:14:07.865] iteration 71900 [238.53 sec]: learning rate : 0.000063 loss : 0.300850 +[08:14:57.871] iteration 72000 [288.53 sec]: learning rate : 0.000063 loss : 0.329346 +[08:14:57.920] Epoch 124 Evaluation: +[08:15:43.507] average MSE: 0.11240452565862444 average PSNR: 22.517787168959206 average SSIM: 0.652164270972159 +[08:16:34.400] iteration 72100 [50.87 sec]: learning rate : 0.000063 loss : 0.314501 +[08:17:24.909] iteration 72200 [101.38 sec]: learning rate : 0.000063 loss : 0.339629 +[08:18:15.236] iteration 72300 [151.73 sec]: learning rate : 0.000063 loss : 0.279141 +[08:19:05.367] iteration 72400 [201.83 sec]: learning rate : 0.000063 loss : 0.387892 +[08:19:55.330] iteration 72500 [251.80 sec]: learning rate : 0.000063 loss : 0.332430 +[08:20:33.407] Epoch 125 Evaluation: +[08:21:17.796] average MSE: 0.08897073024368234 average PSNR: 23.5141470878895 average SSIM: 0.6448932617620894 +[08:21:30.054] iteration 72600 [12.23 sec]: learning rate : 0.000063 loss : 0.324165 +[08:22:20.519] iteration 72700 [62.70 sec]: learning rate : 0.000063 loss : 0.289032 +[08:23:10.511] iteration 72800 [112.69 sec]: learning rate : 0.000063 loss : 0.322988 +[08:24:00.537] iteration 72900 [162.71 sec]: learning rate : 0.000063 loss : 0.296840 +[08:24:50.607] iteration 73000 [212.79 sec]: learning rate : 0.000063 loss : 0.296525 +[08:25:40.698] iteration 73100 [262.90 sec]: learning rate : 0.000063 loss : 0.397359 +[08:26:07.013] Epoch 126 Evaluation: +[08:26:52.386] average MSE: 0.10902931372990826 average PSNR: 22.645302090604968 average SSIM: 0.6493734568671198 +[08:27:16.705] iteration 73200 [24.29 sec]: learning rate : 0.000063 loss : 0.284421 +[08:28:06.613] iteration 73300 [74.20 sec]: learning rate : 0.000063 loss : 0.290895 +[08:28:56.995] iteration 73400 [124.58 sec]: learning rate : 0.000063 loss : 0.293739 +[08:29:47.089] iteration 73500 [174.68 sec]: learning rate : 0.000063 loss : 0.347372 +[08:30:37.077] iteration 73600 [224.66 sec]: learning rate : 
0.000063 loss : 0.346848 +[08:31:27.184] iteration 73700 [274.77 sec]: learning rate : 0.000063 loss : 0.348344 +[08:31:41.190] Epoch 127 Evaluation: +[08:32:27.434] average MSE: 0.10213845566558043 average PSNR: 22.918064360299223 average SSIM: 0.6420718924360815 +[08:33:03.842] iteration 73800 [36.38 sec]: learning rate : 0.000063 loss : 0.230072 +[08:33:54.421] iteration 73900 [86.96 sec]: learning rate : 0.000063 loss : 0.335126 +[08:34:45.099] iteration 74000 [137.64 sec]: learning rate : 0.000063 loss : 0.354729 +[08:35:35.659] iteration 74100 [188.20 sec]: learning rate : 0.000063 loss : 0.261323 +[08:36:25.589] iteration 74200 [238.13 sec]: learning rate : 0.000063 loss : 0.300807 +[08:37:15.619] iteration 74300 [288.16 sec]: learning rate : 0.000063 loss : 0.449836 +[08:37:17.628] Epoch 128 Evaluation: +[08:38:04.757] average MSE: 0.08172072234233113 average PSNR: 23.88459921536335 average SSIM: 0.6440881050250656 +[08:38:52.918] iteration 74400 [48.13 sec]: learning rate : 0.000063 loss : 0.356299 +[08:39:42.982] iteration 74500 [98.20 sec]: learning rate : 0.000063 loss : 0.320566 +[08:40:33.595] iteration 74600 [148.81 sec]: learning rate : 0.000063 loss : 0.249682 +[08:41:24.228] iteration 74700 [199.44 sec]: learning rate : 0.000063 loss : 0.321638 +[08:42:14.358] iteration 74800 [249.57 sec]: learning rate : 0.000063 loss : 0.301742 +[08:42:54.373] Epoch 129 Evaluation: +[08:43:41.078] average MSE: 0.09956519424702058 average PSNR: 23.030374520130618 average SSIM: 0.642505525298715 +[08:43:51.442] iteration 74900 [10.34 sec]: learning rate : 0.000063 loss : 0.325147 +[08:44:41.482] iteration 75000 [60.38 sec]: learning rate : 0.000063 loss : 0.312025 +[08:45:31.503] iteration 75100 [110.40 sec]: learning rate : 0.000063 loss : 0.304097 +[08:46:21.523] iteration 75200 [160.42 sec]: learning rate : 0.000063 loss : 0.266343 +[08:47:12.004] iteration 75300 [210.90 sec]: learning rate : 0.000063 loss : 0.340223 +[08:48:02.062] iteration 75400 [260.96 sec]: learning rate : 0.000063 loss : 0.311826 +[08:48:30.578] Epoch 130 Evaluation: +[08:49:17.344] average MSE: 0.11478494355280688 average PSNR: 22.429533696743263 average SSIM: 0.6470799662738361 +[08:49:39.576] iteration 75500 [22.21 sec]: learning rate : 0.000063 loss : 0.327938 +[08:50:29.733] iteration 75600 [72.36 sec]: learning rate : 0.000063 loss : 0.336143 +[08:51:19.840] iteration 75700 [122.47 sec]: learning rate : 0.000063 loss : 0.317359 +[08:52:09.849] iteration 75800 [172.48 sec]: learning rate : 0.000063 loss : 0.278954 +[08:53:00.954] iteration 75900 [223.58 sec]: learning rate : 0.000063 loss : 0.403755 +[08:53:51.057] iteration 76000 [273.69 sec]: learning rate : 0.000063 loss : 0.316800 +[08:54:07.075] Epoch 131 Evaluation: +[08:54:53.258] average MSE: 0.09996455454103644 average PSNR: 23.014763265753366 average SSIM: 0.6438218223337822 +[08:55:27.454] iteration 76100 [34.17 sec]: learning rate : 0.000063 loss : 0.344702 +[08:56:17.933] iteration 76200 [84.65 sec]: learning rate : 0.000063 loss : 0.311575 +[08:57:08.043] iteration 76300 [134.76 sec]: learning rate : 0.000063 loss : 0.296349 +[08:57:58.041] iteration 76400 [184.76 sec]: learning rate : 0.000063 loss : 0.357152 +[08:58:48.627] iteration 76500 [235.34 sec]: learning rate : 0.000063 loss : 0.331319 +[08:59:38.626] iteration 76600 [285.34 sec]: learning rate : 0.000063 loss : 0.304983 +[08:59:42.632] Epoch 132 Evaluation: +[09:00:27.741] average MSE: 0.12003570428872172 average PSNR: 22.24619152558472 average SSIM: 0.650029446408974 +[09:01:14.042] 
iteration 76700 [46.28 sec]: learning rate : 0.000063 loss : 0.304184 +[09:02:04.513] iteration 76800 [96.75 sec]: learning rate : 0.000063 loss : 0.381065 +[09:02:54.423] iteration 76900 [146.66 sec]: learning rate : 0.000063 loss : 0.320926 +[09:03:44.830] iteration 77000 [197.06 sec]: learning rate : 0.000063 loss : 0.302623 +[09:04:35.198] iteration 77100 [247.43 sec]: learning rate : 0.000063 loss : 0.369670 +[09:05:17.161] Epoch 133 Evaluation: +[09:06:03.049] average MSE: 0.1109567968420502 average PSNR: 22.57196513184964 average SSIM: 0.6530721562949526 +[09:06:11.297] iteration 77200 [8.22 sec]: learning rate : 0.000063 loss : 0.305682 +[09:07:01.408] iteration 77300 [58.33 sec]: learning rate : 0.000063 loss : 0.340832 +[09:07:51.451] iteration 77400 [108.37 sec]: learning rate : 0.000063 loss : 0.303323 +[09:08:41.377] iteration 77500 [158.30 sec]: learning rate : 0.000063 loss : 0.268641 +[09:09:31.411] iteration 77600 [208.33 sec]: learning rate : 0.000063 loss : 0.319349 +[09:10:21.399] iteration 77700 [258.32 sec]: learning rate : 0.000063 loss : 0.297153 +[09:10:52.699] Epoch 134 Evaluation: +[09:11:37.117] average MSE: 0.09579916106834481 average PSNR: 23.194311811029394 average SSIM: 0.6447355314085821 +[09:11:57.402] iteration 77800 [20.26 sec]: learning rate : 0.000063 loss : 0.295930 +[09:12:47.533] iteration 77900 [70.39 sec]: learning rate : 0.000063 loss : 0.333296 +[09:13:37.556] iteration 78000 [120.42 sec]: learning rate : 0.000063 loss : 0.252205 +[09:14:27.694] iteration 78100 [170.55 sec]: learning rate : 0.000063 loss : 0.388554 +[09:15:17.754] iteration 78200 [220.61 sec]: learning rate : 0.000063 loss : 0.278317 +[09:16:07.667] iteration 78300 [270.52 sec]: learning rate : 0.000063 loss : 0.329698 +[09:16:25.710] Epoch 135 Evaluation: +[09:17:10.983] average MSE: 0.09111504848080404 average PSNR: 23.411943059818967 average SSIM: 0.6488732245841643 +[09:17:43.161] iteration 78400 [32.15 sec]: learning rate : 0.000063 loss : 0.291087 +[09:18:33.625] iteration 78500 [82.62 sec]: learning rate : 0.000063 loss : 0.272169 +[09:19:23.568] iteration 78600 [132.56 sec]: learning rate : 0.000063 loss : 0.352359 +[09:20:14.122] iteration 78700 [183.11 sec]: learning rate : 0.000063 loss : 0.348566 +[09:21:04.160] iteration 78800 [233.18 sec]: learning rate : 0.000063 loss : 0.332028 +[09:21:54.101] iteration 78900 [283.09 sec]: learning rate : 0.000063 loss : 0.333809 +[09:22:00.112] Epoch 136 Evaluation: +[09:22:45.752] average MSE: 0.11031844332091956 average PSNR: 22.595096233767563 average SSIM: 0.6460647946072938 +[09:23:30.456] iteration 79000 [44.68 sec]: learning rate : 0.000063 loss : 0.404969 +[09:24:20.404] iteration 79100 [94.63 sec]: learning rate : 0.000063 loss : 0.265309 +[09:25:10.473] iteration 79200 [144.69 sec]: learning rate : 0.000063 loss : 0.411384 +[09:26:00.991] iteration 79300 [195.21 sec]: learning rate : 0.000063 loss : 0.299720 +[09:26:50.925] iteration 79400 [245.15 sec]: learning rate : 0.000063 loss : 0.370239 +[09:27:34.959] Epoch 137 Evaluation: +[09:28:20.214] average MSE: 0.0991914329796342 average PSNR: 23.050162431385875 average SSIM: 0.6461887929902471 +[09:28:26.418] iteration 79500 [6.18 sec]: learning rate : 0.000063 loss : 0.296698 +[09:29:17.154] iteration 79600 [56.91 sec]: learning rate : 0.000063 loss : 0.280756 +[09:30:07.074] iteration 79700 [106.83 sec]: learning rate : 0.000063 loss : 0.310959 +[09:30:57.073] iteration 79800 [156.83 sec]: learning rate : 0.000063 loss : 0.358679 +[09:31:47.129] iteration 79900 
[206.89 sec]: learning rate : 0.000063 loss : 0.297643 +[09:32:37.095] iteration 80000 [256.85 sec]: learning rate : 0.000016 loss : 0.336688 +[09:32:37.254] save model to model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/iter_80000.pth +[09:33:09.350] Epoch 138 Evaluation: +[09:33:56.380] average MSE: 0.10937545744935996 average PSNR: 22.633234808948618 average SSIM: 0.6495248439656719 +[09:34:14.584] iteration 80100 [18.18 sec]: learning rate : 0.000031 loss : 0.297361 +[09:35:05.069] iteration 80200 [68.66 sec]: learning rate : 0.000031 loss : 0.312605 +[09:35:55.064] iteration 80300 [118.66 sec]: learning rate : 0.000031 loss : 0.305165 +[09:36:45.177] iteration 80400 [168.77 sec]: learning rate : 0.000031 loss : 0.251069 +[09:37:35.399] iteration 80500 [218.99 sec]: learning rate : 0.000031 loss : 0.373480 +[09:38:25.536] iteration 80600 [269.13 sec]: learning rate : 0.000031 loss : 0.327118 +[09:38:45.541] Epoch 139 Evaluation: +[09:39:31.930] average MSE: 0.08895757699794073 average PSNR: 23.51582625328265 average SSIM: 0.646485874877862 +[09:40:02.263] iteration 80700 [30.31 sec]: learning rate : 0.000031 loss : 0.331918 +[09:40:52.729] iteration 80800 [80.77 sec]: learning rate : 0.000031 loss : 0.303358 +[09:41:43.119] iteration 80900 [131.16 sec]: learning rate : 0.000031 loss : 0.313055 +[09:42:33.160] iteration 81000 [181.20 sec]: learning rate : 0.000031 loss : 0.333385 +[09:43:23.086] iteration 81100 [231.13 sec]: learning rate : 0.000031 loss : 0.300353 +[09:44:13.131] iteration 81200 [281.17 sec]: learning rate : 0.000031 loss : 0.315816 +[09:44:21.121] Epoch 140 Evaluation: +[09:45:05.965] average MSE: 0.10237085160303963 average PSNR: 22.914289554943633 average SSIM: 0.6495009828040423 +[09:45:48.321] iteration 81300 [42.33 sec]: learning rate : 0.000031 loss : 0.362565 +[09:46:38.804] iteration 81400 [92.81 sec]: learning rate : 0.000031 loss : 0.338594 +[09:47:29.324] iteration 81500 [143.33 sec]: learning rate : 0.000031 loss : 0.394509 +[09:48:19.673] iteration 81600 [193.68 sec]: learning rate : 0.000031 loss : 0.354244 +[09:49:09.728] iteration 81700 [243.74 sec]: learning rate : 0.000031 loss : 0.274425 +[09:49:55.810] Epoch 141 Evaluation: +[09:50:40.771] average MSE: 0.10618730961010145 average PSNR: 22.757728205558458 average SSIM: 0.6481462946591333 +[09:50:44.983] iteration 81800 [4.19 sec]: learning rate : 0.000031 loss : 0.340117 +[09:51:35.002] iteration 81900 [54.21 sec]: learning rate : 0.000031 loss : 0.374909 +[09:52:25.149] iteration 82000 [104.35 sec]: learning rate : 0.000031 loss : 0.373900 +[09:53:15.492] iteration 82100 [154.70 sec]: learning rate : 0.000031 loss : 0.279652 +[09:54:05.605] iteration 82200 [204.81 sec]: learning rate : 0.000031 loss : 0.464209 +[09:54:55.723] iteration 82300 [254.93 sec]: learning rate : 0.000031 loss : 0.281232 +[09:55:30.268] Epoch 142 Evaluation: +[09:56:17.173] average MSE: 0.10062736706814589 average PSNR: 22.984735912665126 average SSIM: 0.6468227702233149 +[09:56:33.403] iteration 82400 [16.20 sec]: learning rate : 0.000031 loss : 0.289340 +[09:57:23.518] iteration 82500 [66.32 sec]: learning rate : 0.000031 loss : 0.321500 +[09:58:13.528] iteration 82600 [116.33 sec]: learning rate : 0.000031 loss : 0.290599 +[09:59:03.760] iteration 82700 [166.56 sec]: learning rate : 0.000031 loss : 0.328682 +[09:59:53.822] iteration 82800 [216.62 sec]: learning rate : 0.000031 loss : 0.276671 +[10:00:43.812] iteration 82900 [266.61 sec]: learning rate : 0.000031 loss : 0.329898 +[10:01:05.806] Epoch 143 Evaluation: 
+[10:01:50.571] average MSE: 0.1039991279242021 average PSNR: 22.842573233782492 average SSIM: 0.6452468964309928 +[10:02:18.730] iteration 83000 [28.13 sec]: learning rate : 0.000031 loss : 0.269680 +[10:03:08.782] iteration 83100 [78.18 sec]: learning rate : 0.000031 loss : 0.262336 +[10:03:59.335] iteration 83200 [128.74 sec]: learning rate : 0.000031 loss : 0.342951 +[10:04:49.525] iteration 83300 [178.93 sec]: learning rate : 0.000031 loss : 0.336256 +[10:05:39.810] iteration 83400 [229.21 sec]: learning rate : 0.000031 loss : 0.379999 +[10:06:29.837] iteration 83500 [279.24 sec]: learning rate : 0.000031 loss : 0.356296 +[10:06:39.834] Epoch 144 Evaluation: +[10:07:24.444] average MSE: 0.09705926109817588 average PSNR: 23.140065639855912 average SSIM: 0.641467007554197 +[10:08:04.607] iteration 83600 [40.14 sec]: learning rate : 0.000031 loss : 0.360128 +[10:08:54.661] iteration 83700 [90.19 sec]: learning rate : 0.000031 loss : 0.312830 +[10:09:44.629] iteration 83800 [140.16 sec]: learning rate : 0.000031 loss : 0.324653 +[10:10:34.818] iteration 83900 [190.35 sec]: learning rate : 0.000031 loss : 0.379482 +[10:11:25.719] iteration 84000 [241.25 sec]: learning rate : 0.000031 loss : 0.297825 +[10:12:13.655] Epoch 145 Evaluation: +[10:12:58.074] average MSE: 0.07701365672782674 average PSNR: 24.142950923945037 average SSIM: 0.647683361362526 +[10:13:00.286] iteration 84100 [2.19 sec]: learning rate : 0.000031 loss : 0.304218 +[10:13:50.578] iteration 84200 [52.48 sec]: learning rate : 0.000031 loss : 0.331468 +[10:14:40.627] iteration 84300 [102.53 sec]: learning rate : 0.000031 loss : 0.400133 +[10:15:30.557] iteration 84400 [152.46 sec]: learning rate : 0.000031 loss : 0.226834 +[10:16:20.570] iteration 84500 [202.47 sec]: learning rate : 0.000031 loss : 0.363520 +[10:17:11.112] iteration 84600 [253.01 sec]: learning rate : 0.000031 loss : 0.303444 +[10:17:47.115] Epoch 146 Evaluation: +[10:18:34.784] average MSE: 0.10113008246939284 average PSNR: 22.963511127827783 average SSIM: 0.6392712315827422 +[10:18:48.981] iteration 84700 [14.17 sec]: learning rate : 0.000031 loss : 0.268750 +[10:19:39.027] iteration 84800 [64.22 sec]: learning rate : 0.000031 loss : 0.293461 +[10:20:28.976] iteration 84900 [114.17 sec]: learning rate : 0.000031 loss : 0.308606 +[10:21:19.035] iteration 85000 [164.22 sec]: learning rate : 0.000031 loss : 0.351887 +[10:22:09.186] iteration 85100 [214.39 sec]: learning rate : 0.000031 loss : 0.314564 +[10:22:59.388] iteration 85200 [264.58 sec]: learning rate : 0.000031 loss : 0.375060 +[10:23:23.925] Epoch 147 Evaluation: +[10:24:10.072] average MSE: 0.11316691861057489 average PSNR: 22.487910226180293 average SSIM: 0.6495968003009153 +[10:24:36.297] iteration 85300 [26.20 sec]: learning rate : 0.000031 loss : 0.278439 +[10:25:26.429] iteration 85400 [76.33 sec]: learning rate : 0.000031 loss : 0.341267 +[10:26:16.801] iteration 85500 [126.70 sec]: learning rate : 0.000031 loss : 0.283792 +[10:27:06.954] iteration 85600 [176.86 sec]: learning rate : 0.000031 loss : 0.320655 +[10:27:57.055] iteration 85700 [226.96 sec]: learning rate : 0.000031 loss : 0.331951 +[10:28:47.061] iteration 85800 [276.96 sec]: learning rate : 0.000031 loss : 0.264880 +[10:28:59.065] Epoch 148 Evaluation: +[10:29:46.967] average MSE: 0.09546411795913781 average PSNR: 23.20988393883347 average SSIM: 0.6408591591233107 +[10:30:25.282] iteration 85900 [38.29 sec]: learning rate : 0.000031 loss : 0.290288 +[10:31:15.587] iteration 86000 [88.59 sec]: learning rate : 0.000031 loss : 
0.344146 +[10:32:05.611] iteration 86100 [138.62 sec]: learning rate : 0.000031 loss : 0.314769 +[10:32:55.699] iteration 86200 [188.71 sec]: learning rate : 0.000031 loss : 0.270392 +[10:33:46.137] iteration 86300 [239.14 sec]: learning rate : 0.000031 loss : 0.306747 +[10:34:36.245] iteration 86400 [289.26 sec]: learning rate : 0.000031 loss : 0.326322 +[10:34:36.291] Epoch 149 Evaluation: +[10:35:22.674] average MSE: 0.11180114202356077 average PSNR: 22.53931357426915 average SSIM: 0.648827855983111 +[10:36:13.048] iteration 86500 [50.35 sec]: learning rate : 0.000031 loss : 0.315951 +[10:37:02.958] iteration 86600 [100.26 sec]: learning rate : 0.000031 loss : 0.401824 +[10:37:53.017] iteration 86700 [150.32 sec]: learning rate : 0.000031 loss : 0.308833 +[10:38:43.124] iteration 86800 [200.42 sec]: learning rate : 0.000031 loss : 0.364716 +[10:39:33.134] iteration 86900 [250.43 sec]: learning rate : 0.000031 loss : 0.293423 +[10:40:11.758] Epoch 150 Evaluation: +[10:40:57.298] average MSE: 0.10815075290820168 average PSNR: 22.67648973024771 average SSIM: 0.6427193557889421 +[10:41:09.501] iteration 87000 [12.18 sec]: learning rate : 0.000031 loss : 0.338497 +[10:41:59.957] iteration 87100 [62.63 sec]: learning rate : 0.000031 loss : 0.371339 +[10:42:49.989] iteration 87200 [112.67 sec]: learning rate : 0.000031 loss : 0.304847 +[10:43:40.112] iteration 87300 [162.79 sec]: learning rate : 0.000031 loss : 0.305700 +[10:44:30.148] iteration 87400 [212.82 sec]: learning rate : 0.000031 loss : 0.278603 +[10:45:20.260] iteration 87500 [262.94 sec]: learning rate : 0.000031 loss : 0.381183 +[10:45:46.271] Epoch 151 Evaluation: +[10:46:33.554] average MSE: 0.09551380442232298 average PSNR: 23.208006451540623 average SSIM: 0.6460366000643422 +[10:46:57.873] iteration 87600 [24.29 sec]: learning rate : 0.000031 loss : 0.230609 +[10:47:48.122] iteration 87700 [74.54 sec]: learning rate : 0.000031 loss : 0.331521 +[10:48:38.609] iteration 87800 [125.03 sec]: learning rate : 0.000031 loss : 0.272747 +[10:49:29.069] iteration 87900 [175.49 sec]: learning rate : 0.000031 loss : 0.362386 +[10:50:19.013] iteration 88000 [225.43 sec]: learning rate : 0.000031 loss : 0.381640 +[10:51:09.070] iteration 88100 [275.49 sec]: learning rate : 0.000031 loss : 0.315227 +[10:51:23.057] Epoch 152 Evaluation: +[10:52:07.995] average MSE: 0.1041351104417591 average PSNR: 22.841175383237864 average SSIM: 0.6459038992134396 +[10:52:44.376] iteration 88200 [36.35 sec]: learning rate : 0.000031 loss : 0.235369 +[10:53:34.696] iteration 88300 [86.67 sec]: learning rate : 0.000031 loss : 0.300574 +[10:54:24.705] iteration 88400 [136.68 sec]: learning rate : 0.000031 loss : 0.336593 +[10:55:14.725] iteration 88500 [186.70 sec]: learning rate : 0.000031 loss : 0.299463 +[10:56:05.001] iteration 88600 [236.98 sec]: learning rate : 0.000031 loss : 0.322230 +[10:56:54.992] iteration 88700 [286.97 sec]: learning rate : 0.000031 loss : 0.368870 +[10:56:56.999] Epoch 153 Evaluation: +[10:57:41.717] average MSE: 0.10371318031048943 average PSNR: 22.85703229720043 average SSIM: 0.6461860130276676 +[10:58:30.320] iteration 88800 [48.58 sec]: learning rate : 0.000031 loss : 0.296571 +[10:59:20.457] iteration 88900 [98.71 sec]: learning rate : 0.000031 loss : 0.298148 +[11:00:11.077] iteration 89000 [149.33 sec]: learning rate : 0.000031 loss : 0.225822 +[11:01:01.016] iteration 89100 [199.27 sec]: learning rate : 0.000031 loss : 0.309216 +[11:01:51.081] iteration 89200 [249.34 sec]: learning rate : 0.000031 loss : 0.297765 
+[11:02:31.082] Epoch 154 Evaluation: +[11:03:17.808] average MSE: 0.10048884749642287 average PSNR: 22.990020073505928 average SSIM: 0.6419146040513488 +[11:03:28.389] iteration 89300 [10.56 sec]: learning rate : 0.000031 loss : 0.327841 +[11:04:18.457] iteration 89400 [60.62 sec]: learning rate : 0.000031 loss : 0.365489 +[11:05:08.631] iteration 89500 [110.80 sec]: learning rate : 0.000031 loss : 0.286757 +[11:05:59.026] iteration 89600 [161.19 sec]: learning rate : 0.000031 loss : 0.297181 +[11:06:49.424] iteration 89700 [211.59 sec]: learning rate : 0.000031 loss : 0.384609 +[11:07:39.490] iteration 89800 [261.66 sec]: learning rate : 0.000031 loss : 0.295659 +[11:08:07.484] Epoch 155 Evaluation: +[11:08:52.068] average MSE: 0.11146284756226632 average PSNR: 22.551953396168596 average SSIM: 0.6502847736478327 +[11:09:14.253] iteration 89900 [22.16 sec]: learning rate : 0.000031 loss : 0.369095 +[11:10:04.373] iteration 90000 [72.28 sec]: learning rate : 0.000031 loss : 0.298970 +[11:10:54.789] iteration 90100 [122.69 sec]: learning rate : 0.000031 loss : 0.311708 +[11:11:44.999] iteration 90200 [172.90 sec]: learning rate : 0.000031 loss : 0.296930 +[11:12:35.018] iteration 90300 [222.92 sec]: learning rate : 0.000031 loss : 0.353140 +[11:13:25.073] iteration 90400 [272.98 sec]: learning rate : 0.000031 loss : 0.316770 +[11:13:41.062] Epoch 156 Evaluation: +[11:14:25.996] average MSE: 0.09880360514161092 average PSNR: 23.06259119298642 average SSIM: 0.6425388745549993 +[11:15:00.159] iteration 90500 [34.14 sec]: learning rate : 0.000031 loss : 0.360024 +[11:15:50.611] iteration 90600 [84.59 sec]: learning rate : 0.000031 loss : 0.358573 +[11:16:40.616] iteration 90700 [134.59 sec]: learning rate : 0.000031 loss : 0.291260 +[11:17:30.564] iteration 90800 [184.54 sec]: learning rate : 0.000031 loss : 0.354920 +[11:18:21.266] iteration 90900 [235.26 sec]: learning rate : 0.000031 loss : 0.374382 +[11:19:11.377] iteration 91000 [285.36 sec]: learning rate : 0.000031 loss : 0.300934 +[11:19:15.383] Epoch 157 Evaluation: +[11:20:01.633] average MSE: 0.09350937236218744 average PSNR: 23.300044092767582 average SSIM: 0.6480196001233081 +[11:20:47.842] iteration 91100 [46.18 sec]: learning rate : 0.000031 loss : 0.332318 +[11:21:38.008] iteration 91200 [96.35 sec]: learning rate : 0.000031 loss : 0.379484 +[11:22:28.001] iteration 91300 [146.34 sec]: learning rate : 0.000031 loss : 0.420803 +[11:23:18.218] iteration 91400 [196.56 sec]: learning rate : 0.000031 loss : 0.287267 +[11:24:08.732] iteration 91500 [247.07 sec]: learning rate : 0.000031 loss : 0.386983 +[11:24:51.121] Epoch 158 Evaluation: +[11:25:35.999] average MSE: 0.09291579769719185 average PSNR: 23.32592754966249 average SSIM: 0.6427999893715373 +[11:25:44.220] iteration 91600 [8.19 sec]: learning rate : 0.000031 loss : 0.310538 +[11:26:34.768] iteration 91700 [58.74 sec]: learning rate : 0.000031 loss : 0.336889 +[11:27:24.805] iteration 91800 [108.78 sec]: learning rate : 0.000031 loss : 0.290830 +[11:28:14.751] iteration 91900 [158.73 sec]: learning rate : 0.000031 loss : 0.283731 +[11:29:04.793] iteration 92000 [208.77 sec]: learning rate : 0.000031 loss : 0.296987 +[11:29:55.298] iteration 92100 [259.27 sec]: learning rate : 0.000031 loss : 0.273952 +[11:30:25.272] Epoch 159 Evaluation: +[11:31:09.451] average MSE: 0.09805869137897169 average PSNR: 23.095779222522914 average SSIM: 0.6436109760669413 +[11:31:29.637] iteration 92200 [20.16 sec]: learning rate : 0.000031 loss : 0.283691 +[11:32:19.692] iteration 92300 [70.21 
sec]: learning rate : 0.000031 loss : 0.324993 +[11:33:09.708] iteration 92400 [120.23 sec]: learning rate : 0.000031 loss : 0.219657 +[11:34:00.392] iteration 92500 [170.91 sec]: learning rate : 0.000031 loss : 0.401106 +[11:34:50.429] iteration 92600 [220.95 sec]: learning rate : 0.000031 loss : 0.267417 +[11:35:40.368] iteration 92700 [270.89 sec]: learning rate : 0.000031 loss : 0.327598 +[11:35:58.731] Epoch 160 Evaluation: +[11:36:43.128] average MSE: 0.11181069416263793 average PSNR: 22.536673780734432 average SSIM: 0.650195514649286 +[11:37:15.311] iteration 92800 [32.16 sec]: learning rate : 0.000031 loss : 0.323939 +[11:38:05.397] iteration 92900 [82.24 sec]: learning rate : 0.000031 loss : 0.258167 +[11:38:55.359] iteration 93000 [132.20 sec]: learning rate : 0.000031 loss : 0.353527 +[11:39:45.491] iteration 93100 [182.34 sec]: learning rate : 0.000031 loss : 0.341790 +[11:40:35.579] iteration 93200 [232.43 sec]: learning rate : 0.000031 loss : 0.362424 +[11:41:25.815] iteration 93300 [282.66 sec]: learning rate : 0.000031 loss : 0.307419 +[11:41:31.817] Epoch 161 Evaluation: +[11:42:18.435] average MSE: 0.11022703269960087 average PSNR: 22.597827011580932 average SSIM: 0.6509159963042537 +[11:43:03.211] iteration 93400 [44.75 sec]: learning rate : 0.000031 loss : 0.353563 +[11:43:53.292] iteration 93500 [94.83 sec]: learning rate : 0.000031 loss : 0.306491 +[11:44:43.274] iteration 93600 [144.81 sec]: learning rate : 0.000031 loss : 0.419300 +[11:45:33.335] iteration 93700 [194.88 sec]: learning rate : 0.000031 loss : 0.298005 +[11:46:23.364] iteration 93800 [244.90 sec]: learning rate : 0.000031 loss : 0.331182 +[11:47:07.298] Epoch 162 Evaluation: +[11:47:53.983] average MSE: 0.0836403647210981 average PSNR: 23.785785823964023 average SSIM: 0.6501080758214023 +[11:48:00.206] iteration 93900 [6.20 sec]: learning rate : 0.000031 loss : 0.383435 +[11:48:50.758] iteration 94000 [56.75 sec]: learning rate : 0.000031 loss : 0.254878 +[11:49:40.786] iteration 94100 [106.78 sec]: learning rate : 0.000031 loss : 0.290183 +[11:50:30.893] iteration 94200 [156.88 sec]: learning rate : 0.000031 loss : 0.354735 +[11:51:21.315] iteration 94300 [207.31 sec]: learning rate : 0.000031 loss : 0.316008 +[11:52:11.302] iteration 94400 [257.29 sec]: learning rate : 0.000031 loss : 0.350041 +[11:52:43.431] Epoch 163 Evaluation: +[11:53:29.264] average MSE: 0.10697219485943545 average PSNR: 22.724835383019045 average SSIM: 0.6475077502434585 +[11:53:47.907] iteration 94500 [18.62 sec]: learning rate : 0.000031 loss : 0.362215 +[11:54:37.980] iteration 94600 [68.69 sec]: learning rate : 0.000031 loss : 0.306926 +[11:55:27.901] iteration 94700 [118.61 sec]: learning rate : 0.000031 loss : 0.323203 +[11:56:18.383] iteration 94800 [169.09 sec]: learning rate : 0.000031 loss : 0.249837 +[11:57:08.423] iteration 94900 [219.13 sec]: learning rate : 0.000031 loss : 0.330123 +[11:57:58.343] iteration 95000 [269.05 sec]: learning rate : 0.000031 loss : 0.368739 +[11:58:18.320] Epoch 164 Evaluation: +[11:59:03.987] average MSE: 0.10523277179065511 average PSNR: 22.796104026009573 average SSIM: 0.6536056885049797 +[11:59:34.295] iteration 95100 [30.28 sec]: learning rate : 0.000031 loss : 0.344557 +[12:00:25.191] iteration 95200 [81.18 sec]: learning rate : 0.000031 loss : 0.296907 +[12:01:15.245] iteration 95300 [131.23 sec]: learning rate : 0.000031 loss : 0.366446 +[12:02:05.366] iteration 95400 [181.37 sec]: learning rate : 0.000031 loss : 0.301082 +[12:02:55.382] iteration 95500 [231.37 sec]: learning 
rate : 0.000031 loss : 0.346390 +[12:03:45.779] iteration 95600 [281.77 sec]: learning rate : 0.000031 loss : 0.308911 +[12:03:53.768] Epoch 165 Evaluation: +[12:04:38.681] average MSE: 0.10967524919625582 average PSNR: 22.619149785429407 average SSIM: 0.651124695168487 +[12:05:21.089] iteration 95700 [42.38 sec]: learning rate : 0.000031 loss : 0.375290 +[12:06:11.454] iteration 95800 [92.75 sec]: learning rate : 0.000031 loss : 0.299627 +[12:07:01.472] iteration 95900 [142.77 sec]: learning rate : 0.000031 loss : 0.411125 +[12:07:51.586] iteration 96000 [192.88 sec]: learning rate : 0.000031 loss : 0.346892 +[12:08:41.531] iteration 96100 [242.82 sec]: learning rate : 0.000031 loss : 0.282223 +[12:09:28.190] Epoch 166 Evaluation: +[12:10:13.350] average MSE: 0.10784325543466491 average PSNR: 22.690094218129218 average SSIM: 0.652441842856129 +[12:10:17.554] iteration 96200 [4.18 sec]: learning rate : 0.000031 loss : 0.338779 +[12:11:07.996] iteration 96300 [54.62 sec]: learning rate : 0.000031 loss : 0.344550 +[12:11:58.204] iteration 96400 [104.83 sec]: learning rate : 0.000031 loss : 0.344906 +[12:12:48.248] iteration 96500 [154.87 sec]: learning rate : 0.000031 loss : 0.322932 +[12:13:38.256] iteration 96600 [204.88 sec]: learning rate : 0.000031 loss : 0.471722 +[12:14:28.330] iteration 96700 [254.95 sec]: learning rate : 0.000031 loss : 0.287626 +[12:15:02.498] Epoch 167 Evaluation: +[12:15:49.884] average MSE: 0.09273840173352034 average PSNR: 23.335257429504942 average SSIM: 0.6427813273360473 +[12:16:06.107] iteration 96800 [16.20 sec]: learning rate : 0.000031 loss : 0.282684 +[12:16:56.079] iteration 96900 [66.17 sec]: learning rate : 0.000031 loss : 0.322044 +[12:17:46.235] iteration 97000 [116.32 sec]: learning rate : 0.000031 loss : 0.301621 +[12:18:37.612] iteration 97100 [167.71 sec]: learning rate : 0.000031 loss : 0.336936 +[12:19:27.627] iteration 97200 [217.72 sec]: learning rate : 0.000031 loss : 0.282997 +[12:20:17.693] iteration 97300 [267.78 sec]: learning rate : 0.000031 loss : 0.314091 +[12:20:39.706] Epoch 168 Evaluation: +[12:21:25.187] average MSE: 0.09692045206337448 average PSNR: 23.145580953612725 average SSIM: 0.6488164200165792 +[12:21:53.573] iteration 97400 [28.36 sec]: learning rate : 0.000031 loss : 0.229106 +[12:22:43.565] iteration 97500 [78.35 sec]: learning rate : 0.000031 loss : 0.238151 +[12:23:33.585] iteration 97600 [128.37 sec]: learning rate : 0.000031 loss : 0.387978 +[12:24:23.944] iteration 97700 [178.73 sec]: learning rate : 0.000031 loss : 0.329700 +[12:25:13.952] iteration 97800 [228.74 sec]: learning rate : 0.000031 loss : 0.362750 +[12:26:04.524] iteration 97900 [279.31 sec]: learning rate : 0.000031 loss : 0.354177 +[12:26:14.521] Epoch 169 Evaluation: +[12:26:59.216] average MSE: 0.08775630280386011 average PSNR: 23.574454322383666 average SSIM: 0.6437299382070473 +[12:27:39.663] iteration 98000 [40.42 sec]: learning rate : 0.000031 loss : 0.354343 +[12:28:29.766] iteration 98100 [90.52 sec]: learning rate : 0.000031 loss : 0.375580 +[12:29:19.773] iteration 98200 [140.53 sec]: learning rate : 0.000031 loss : 0.341002 +[12:30:09.945] iteration 98300 [190.70 sec]: learning rate : 0.000031 loss : 0.326393 +[12:30:59.954] iteration 98400 [240.71 sec]: learning rate : 0.000031 loss : 0.289447 +[12:31:48.051] Epoch 170 Evaluation: +[12:32:32.894] average MSE: 0.10630195387367858 average PSNR: 22.7511607135772 average SSIM: 0.6512774146238157 +[12:32:35.108] iteration 98500 [2.19 sec]: learning rate : 0.000031 loss : 0.303655 
+[12:33:25.023] iteration 98600 [52.10 sec]: learning rate : 0.000031 loss : 0.293175 +[12:34:15.456] iteration 98700 [102.54 sec]: learning rate : 0.000031 loss : 0.419576 +[12:35:05.376] iteration 98800 [152.46 sec]: learning rate : 0.000031 loss : 0.238848 +[12:35:55.386] iteration 98900 [202.46 sec]: learning rate : 0.000031 loss : 0.374521 +[12:36:46.054] iteration 99000 [253.13 sec]: learning rate : 0.000031 loss : 0.312633 +[12:37:22.022] Epoch 171 Evaluation: +[12:38:07.335] average MSE: 0.09721208119066452 average PSNR: 23.13176650263914 average SSIM: 0.6460555638157585 +[12:38:21.531] iteration 99100 [14.17 sec]: learning rate : 0.000031 loss : 0.257370 +[12:39:11.606] iteration 99200 [64.25 sec]: learning rate : 0.000031 loss : 0.322189 +[12:40:01.632] iteration 99300 [114.27 sec]: learning rate : 0.000031 loss : 0.305797 +[12:40:51.625] iteration 99400 [164.29 sec]: learning rate : 0.000031 loss : 0.343990 +[12:41:42.026] iteration 99500 [214.67 sec]: learning rate : 0.000031 loss : 0.337829 +[12:42:32.702] iteration 99600 [265.34 sec]: learning rate : 0.000031 loss : 0.332146 +[12:42:56.705] Epoch 172 Evaluation: +[12:43:41.832] average MSE: 0.11603628866195552 average PSNR: 22.383733196006254 average SSIM: 0.6541458589046161 +[12:44:08.058] iteration 99700 [26.20 sec]: learning rate : 0.000031 loss : 0.314870 +[12:44:58.153] iteration 99800 [76.29 sec]: learning rate : 0.000031 loss : 0.336904 +[12:45:48.514] iteration 99900 [126.66 sec]: learning rate : 0.000031 loss : 0.320834 +[12:46:38.600] iteration 100000 [176.74 sec]: learning rate : 0.000008 loss : 0.313221 +[12:46:38.759] save model to model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/iter_100000.pth +[12:46:39.273] Epoch 173 Evaluation: +[12:47:24.548] average MSE: 0.11317819169747931 average PSNR: 22.487750321826052 average SSIM: 0.653872191029069 +[12:47:24.877] save model to model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/iter_100000.pth +===> Evaluate Metric <=== +Results +------------------------------------ +ColdDiffusion NMSE: 1.1031 ± 0.0725 +ColdDiffusion PSNR: 33.5017 ± 0.4869 +ColdDiffusion SSIM: 0.8979 ± 0.0076 +------------------------------------ +All NMSE: 1.1008 ± 0.1699 +All PSNR: 32.4622 ± 0.8115 +All SSIM: 0.8831 ± 0.0128 +------------------------------------ \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/log/events.out.tfevents.1752550634.GCRSANDBOX133 b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/log/events.out.tfevents.1752550634.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..ba7ce7ecfb0447b0f3d4365e4af5bb7e2e27b0fa --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/model/FSMNet_m4raw_4x_lr5e-4_t5_new_kspace_time/log/events.out.tfevents.1752550634.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16fc9eed63fc301cf5f4ffe9e7238167f0ab46b14c5cd3ed7145c75b8e3a9451 +size 40 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__init__.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__pycache__/__init__.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0ddc6d915a070e5e306a8b9cc747ba4af715050d Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__pycache__/common_freq.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__pycache__/common_freq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7676d3dc7e470abf53c7bf98d508e710bf417252 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__pycache__/common_freq.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__pycache__/mynet.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__pycache__/mynet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66c7c31ec0d36f2b986f4fd6520cedfa15236c00 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/__pycache__/mynet.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/common_freq.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..79cf3e778029a846b4da910c115c8315bf33dbaf --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/common_freq.py @@ -0,0 +1,389 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
+ for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs): + + out = self.layers(inputs) + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features): + super(ResBlock, self).__init__() + self.layers = nn.Sequential( + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1), + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + ) + + def forward(self, inputs): + return F.relu(self.layers(inputs) + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [ + ResBlock(n_feat) for _ in range(n_resblocks)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, 
norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x): + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels, args): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + 
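A minimal usage sketch for the blocks defined above: FreBlock9 refines amplitude and phase in the rFFT domain, and FuseBlock7 fuses the spatial and frequency branches with cross-attention. It assumes common_freq.py is importable as networks.common_freq (i.e. the FSMNet directory added in this diff is on the path); the tensor sizes are illustrative only.

import torch
from networks.common_freq import FreBlock9, FuseBlock7

spa = torch.randn(1, 64, 64, 64)                 # spatial-branch features (B, C, H, W)
fre = torch.randn(1, 64, 64, 64)                 # frequency-branch features, same shape

freq_block = FreBlock9(channels=64, args=None)   # `args` is accepted but not used by the block
fuse_block = FuseBlock7(channels=64)

fre = freq_block(fre)       # rfft2 -> separate amplitude/phase conv paths -> irfft2; shape preserved
out = fuse_block(spa, fre)  # cross-attention between the two branches, then a sigmoid-gated sum
print(out.shape)            # torch.Size([1, 64, 64, 64])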
+class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, 
posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ARTUnet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae23fed16b0d9454a5ea09f17b4322f1e9d7e928 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ARTUnet.py @@ -0,0 +1,751 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
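# padded positions were marked -1 in `mask` above; filling the corresponding attention entries with NEG_INF drives their softmax weights to ~0 +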
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + # print("x shape:", x.shape) + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +########################################################################## +class CS_TransformerBlock(nn.Module): + r""" 在ART Transformer Block的基础上添加了channel attention的分支。 + 参考论文: Dual Aggregation Transformer for Image Super-Resolution + 及其代码: https://github.com/zhengchen1999/DAT/blob/main/basicsr/archs/dat_arch.py + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.InstanceNorm2d(dim), + nn.GELU()) + + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + # nn.InstanceNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1)) + + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.InstanceNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # print("x before dwconv:", x.shape) + # convolution output + conv_x = self.dwconv(x.permute(0, 3, 1, 2)) + # conv_x = x.permute(0, 3, 1, 2) + # print("x after dwconv:", x.shape) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + attened_x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + attened_x = attened_x[:, :H, :W, :].contiguous() + + attened_x = attened_x.view(B, H * W, C) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C) + # S-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W) + spatial_map = self.spatial_interaction(attention_reshape) + + # C-I + attened_x = attened_x * torch.sigmoid(channel_map) + # S-I + conv_x = torch.sigmoid(spatial_map) * conv_x + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + + x = attened_x + conv_x + + # 
FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ARTUnet_new.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ARTUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d210d46e2a4d73781b3b16b63a53fdc0f25b9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ARTUnet_new.py @@ -0,0 +1,823 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +Unet的格式构建image restoration 网络。每个transformer block中都是dense self-attention和sparse self-attention交替连接。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True, visualize_attention_maps=False): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + self.visualize_attention_maps = visualize_attention_maps + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + # assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + # qkv: 3, B_, num_heads, N, C // num_heads + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + # B_, num_heads, N, C // num_heads * + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + attn_map = attn.detach().cpu().numpy().mean(axis=1) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) # (B * nP, nHead, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + # attn: B * nP, num_heads, N, N + # v: B * nP, num_heads, N, C // num_heads + # -attn-@-v-> B * nP, num_heads, N, C // num_heads + # -transpose-> B * nP, N, num_heads, C 
// num_heads + # -reshape-> B * nP, N, C + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, G ** 2, G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, Gh * Gw, Gh * Gw) + else: + x = self.attn(x, Gh, Gw, mask=attn_mask) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +class ConcatTransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. 
+ num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + ################## + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // 
self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + ################## + + x = torch.cat((x, y), dim=1) + + shortcut = x + x = self.norm1(x) + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, 2 * Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, 2 * G ** 2, 2 * G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, 2 * Gh * Gw, 2 * Gh * Gw) + else: + x = self.attn(x, 2 * Gh, Gw, mask=attn_mask) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + output = x + + # print(x.shape, Gh, Gw) + x = output[:, :Gh * Gw, :] + y = output[:, Gh * Gw:, :] + assert x.shape == y.shape + + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = y.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = y.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, y, attn_map + else: + return x, y + + def get_patches(self, x, x_size): + H, W = x_size + B, H, W, C = x.shape + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, 
Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + return x, attn_mask, (Hd, Wd), (Gh, Gw), (pad_r, pad_b), G, I + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ART_Restormer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ART_Restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c23123f0ac1a8e82e2bf907658ce8d91951c3e4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ART_Restormer.py @@ -0,0 +1,592 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure, 每个block都是由ART和Restormer串联而成。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
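# sparse attention: tokens sharing the same offset within each interval-spaced grid form one attention group; padded entries are masked out with NEG_INF below +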
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[4, 4, 4, 4], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + + self.down1_2 = Downsample(dim) ## 
From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + 
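+        # level-2 -> level-1 upsampling halves the channels back to `dim`; concatenating the level-1 encoder skip
+        # in forward() restores 2*dim channels, and since no 1x1 reduction conv follows here, the level-1 decoder,
+        # the refinement blocks and the output conv all run on 2*dim channels (64 for the default dim=32)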
self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + + ######################################### Encoder level1 ############################################ + ### ART encoder block + for layer in self.ART_encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + + ### Restormer encoder block + out_enc_level1 = rearrange(out_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = self.Restormer_encoder_level1(out_enc_level1) + # stack.append(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + ######################################### Encoder level2 ############################################ + ### ART encoder block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.ART_encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + out_enc_level2 = rearrange(out_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = self.Restormer_encoder_level2(out_enc_level2) + # stack.append(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.ART_encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + out_enc_level3 = rearrange(out_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = self.Restormer_encoder_level3(out_enc_level3) + # stack.append(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 
############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.ART_latent: + latent = layer(latent, [H // 8, W // 8]) + + ### Restormer encoder block + latent = rearrange(latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = self.Restormer_latent(latent) + # stack.append(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + out_dec_level3 = inp_dec_level3 + for layer in self.ART_decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + out_dec_level3 = rearrange(out_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + out_dec_level3 = self.Restormer_decoder_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + out_dec_level2 = inp_dec_level2 + for layer in self.ART_decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + out_dec_level2 = rearrange(out_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + out_dec_level2 = self.Restormer_decoder_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + out_dec_level1 = inp_dec_level1 + for layer in self.ART_decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + ### Restormer decoder block + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_decoder_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_refinement(out_dec_level1) + + out_dec_level1 = 
self.output(out_dec_level1) + inp_img
+
+        return out_dec_level1 #, stack
+
+
+
+def build_model(args):
+    return ART_Restormer(
+        inp_channels=1,
+        out_channels=1,
+        dim=32,
+        num_art_blocks=[2, 4, 4, 6],
+        num_restormer_blocks=[2, 2, 2, 2],
+        num_refinement_blocks=4,
+        heads=[1, 2, 4, 8],
+        window_size=[10, 10, 10, 10], mlp_ratio=4.,
+        qkv_bias=True, qk_scale=None,
+        interval=[24, 12, 6, 3],
+        ffn_expansion_factor = 2.66,
+        LayerNorm_type = 'WithBias',  ## Other option 'BiasFree'
+        bias=False)
+
+
diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ART_Restormer_v2.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ART_Restormer_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..de8f921e81d79cb7af14a75326f0ddaaf3b3f4c3
--- /dev/null
+++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ART_Restormer_v2.py
@@ -0,0 +1,663 @@
+"""
+2023/07/24, Xiaohan Xing
+U-shape structure; each block is built from an ART branch and a Restormer branch connected in parallel.
+The input features go through the ART and Restormer branches separately; the two sets of features are concatenated, and a conv layer transforms them into the block's output features.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from pdb import set_trace as stx
+import numbers
+
+from einops import rearrange
+import math
+
+from .restormer import TransformerBlock as RestormerBlock
+
+NEG_INF = -1000000
+
+
+##########################################################################
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class DynamicPosBias(nn.Module):
+    def __init__(self, dim, num_heads):
+        super().__init__()
+        self.num_heads = num_heads
+        self.pos_dim = dim // 4
+        self.pos_proj = nn.Linear(2, self.pos_dim)
+        self.pos1 = nn.Sequential(
+            nn.LayerNorm(self.pos_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.pos_dim, self.pos_dim),
+        )
+        self.pos2 = nn.Sequential(
+            nn.LayerNorm(self.pos_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.pos_dim, self.pos_dim)
+        )
+        self.pos3 = nn.Sequential(
+            nn.LayerNorm(self.pos_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.pos_dim, self.num_heads)
+        )
+
+    def forward(self, biases):
+        pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
+        return pos
+
+    def flops(self, N):
+        flops = N * 2 * self.pos_dim
+        flops += N * self.pos_dim * self.pos_dim
+        flops += N * self.pos_dim * self.pos_dim
+        flops += N * self.pos_dim * self.num_heads
+        return flops
+
+
+#########################################
+class Attention(nn.Module):
+    r""" Multi-head self attention module with dynamic position bias.
+
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.enc_fuse_level1 = nn.Conv2d(int(dim * 
2 ** 1), dim, kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.enc_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.enc_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + self.enc_fuse_level4 = nn.Conv2d(int(dim * 2 ** 4), int(dim * 2 ** 3), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.dec_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], 
window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.dec_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.dec_fuse_level1 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + # self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + # bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + # self.fuse_refinement = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + # def forward(self, inp_img, aux_img): + # stack = [] + # bs, _, H, W = inp_img.shape + + # fuse_img = torch.cat((inp_img, aux_img), 1) + # inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + + def forward(self, inp_img): + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + + ######################################### Encoder level1 ############################################ + art_enc_level1 = inp_enc_level1 + restor_enc_level1 = inp_enc_level1 + + ### ART encoder block + for layer in self.ART_encoder_level1: + art_enc_level1 = layer(art_enc_level1, [H, W]) + + ### Restormer encoder block + restor_enc_level1 = rearrange(restor_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_enc_level1 = self.Restormer_encoder_level1(restor_enc_level1) ### (b, c, h, w) + + art_enc_level1 = rearrange(art_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = torch.cat((art_enc_level1, restor_enc_level1), 1) + out_enc_level1 = self.enc_fuse_level1(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + + ######################################### Encoder level2 ############################################ + ### ART encoder 
block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + art_enc_level2 = inp_enc_level2 + restor_enc_level2 = inp_enc_level2 + + for layer in self.ART_encoder_level2: + art_enc_level2 = layer(art_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + restor_enc_level2 = rearrange(restor_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + restor_enc_level2 = self.Restormer_encoder_level2(restor_enc_level2) + + art_enc_level2 = rearrange(art_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = torch.cat((art_enc_level2, restor_enc_level2), 1) + out_enc_level2 = self.enc_fuse_level2(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + art_enc_level3 = inp_enc_level3 + restor_enc_level3 = inp_enc_level3 + + for layer in self.ART_encoder_level3: + art_enc_level3 = layer(art_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + restor_enc_level3 = rearrange(restor_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + restor_enc_level3 = self.Restormer_encoder_level3(restor_enc_level3) + + art_enc_level3 = rearrange(art_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = torch.cat((art_enc_level3, restor_enc_level3), 1) + out_enc_level3 = self.enc_fuse_level3(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 ############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + art_latent = inp_enc_level4 + restor_latent = inp_enc_level4 + + for layer in self.ART_latent: + art_latent = layer(art_latent, [H // 8, W // 8]) + + ### Restormer encoder block + restor_latent = rearrange(restor_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + restor_latent = self.Restormer_latent(restor_latent) + + art_latent = rearrange(art_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = torch.cat((art_latent, restor_latent), 1) + latent = self.enc_fuse_level4(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + art_dec_level3 = inp_dec_level3 + restor_dec_level3 = inp_dec_level3 + + for layer in self.ART_decoder_level3: + art_dec_level3 = layer(art_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + restor_dec_level3 = rearrange(restor_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + restor_dec_level3 = 
self.Restormer_decoder_level3(restor_dec_level3) + + art_dec_level3 = rearrange(art_dec_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_dec_level3 = torch.cat((art_dec_level3, restor_dec_level3), 1) + out_dec_level3 = self.dec_fuse_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + art_dec_level2 = inp_dec_level2 + restor_dec_level2 = inp_dec_level2 + + for layer in self.ART_decoder_level2: + art_dec_level2 = layer(art_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + restor_dec_level2 = rearrange(restor_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + restor_dec_level2 = self.Restormer_decoder_level2(restor_dec_level2) + + art_dec_level2 = rearrange(art_dec_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_dec_level2 = torch.cat((art_dec_level2, restor_dec_level2), 1) + out_dec_level2 = self.dec_fuse_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + art_dec_level1 = inp_dec_level1 + restor_dec_level1 = inp_dec_level1 + + for layer in self.ART_decoder_level1: + art_dec_level1 = layer(art_dec_level1, [H, W]) + + ### Restormer decoder block + restor_dec_level1 = rearrange(restor_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_dec_level1 = self.Restormer_decoder_level1(restor_dec_level1) + + art_dec_level1 = rearrange(art_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = torch.cat((art_dec_level1, restor_dec_level1), 1) + out_dec_level1 = self.dec_fuse_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +# def build_model(args): +# return ART_Restormer( +# inp_channels=2, +# out_channels=1, +# dim=32, +# num_art_blocks=[2, 4, 4, 6], +# num_restormer_blocks=[2, 2, 2, 2], +# num_refinement_blocks=4, +# heads=[1, 2, 4, 8], +# window_size=[10, 10, 10, 10], mlp_ratio=4., +# qkv_bias=True, qk_scale=None, +# interval=[24, 12, 6, 3], +# ffn_expansion_factor = 2.66, +# LayerNorm_type = 'WithBias', ## Other option 'BiasFree' +# bias=False) + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 
2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ARTfuse_layer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ARTfuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..78af0b27528b50db897936f6b8b4eee51d566b1c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/ARTfuse_layer.py @@ -0,0 +1,705 @@ +""" +ART layer in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input to the Attention layer:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + # print("q range:", q.max(), q.min()) + # print("k range:", k.max(), k.min()) + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + # print("attention matrix:", attn.shape, attn) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + # print("feature before proj-layer:", x.max(), x.min()) + x = self.proj(x) + # print("feature after proj-layer:", x.max(), x.min()) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. 
+
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size: window size of dense attention
+        interval: interval size of sparse attention
+        ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 window_size=7,
+                 interval=8,
+                 ds_flag=0,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm,
+                 pre_norm=True):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.interval = interval
+        self.ds_flag = ds_flag
+        self.mlp_ratio = mlp_ratio
+
+        # self.conv = nn.Conv2d(dim, dim, kernel_size=1)
+
+        self.norm1 = norm_layer(dim)
+
+        self.attn = Attention(
+            dim, num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
+            position_bias=True)
+
+        # self.conv = ConvBlock(self.dim, self.dim, drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        # self.bn_norm = nn.BatchNorm2d(dim)
+        self.pre_norm = pre_norm
+
+    def forward(self, x, x_size):
+        H, W = x_size
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)
+
+        if min(H, W) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.ds_flag = 0
+            self.window_size = min(H, W)
+
+        # x = self.conv(x)
+        shortcut = x
+        if self.pre_norm:
+            # print("using pre_norm")
+            x = self.norm1(x)    ## the feature range becomes very small after normalization
+        x = x.view(B, H, W, C)
+        # print("normalized input range:", x.max(), x.min(), x.mean())
+
+        # padding
+        size_par = self.interval if self.ds_flag == 1 else self.window_size
+        pad_l = pad_t = 0
+        pad_r = (size_par - W % size_par) % size_par
+        pad_b = (size_par - H % size_par) % size_par
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hd, Wd, _ = x.shape
+
+        mask = torch.zeros((1, Hd, Wd, 1), device=x.device)
+        if pad_b > 0:
+            mask[:, -pad_b:, :, :] = -1
+        if pad_r > 0:
+            mask[:, :, -pad_r:, :] = -1
+
+        # print("pad_b and pad_r:", pad_b, pad_r)
+
+        # partition the whole feature map into several groups
+        if self.ds_flag == 0:  # Dense Attention
+            G = Gh = Gw = self.window_size
+            x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous()
+            x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C)
+            nP = Hd * Wd // G ** 2  # number of partitioning groups
+            # attn_mask
+            if pad_r > 0 or pad_b > 0:
+                mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous()
+                mask = mask.reshape(nP, 1, G * G)
+                attn_mask = torch.zeros((nP, G * G, G * G), device=x.device)
+                attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF)
+            else:
+                attn_mask = None
+        if self.ds_flag ==
1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + # print("range of feature before layer:", x.max(), x.min(), x.mean()) + # x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + # x_bn = self.bn_norm(x.permute(0, 3, 1, 2)) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + # print("[ART layer] input and output feature difference:", torch.mean(torch.abs(shortcut - x))) + # print("range of feature before ART layer:", shortcut.max(), shortcut.min(), shortcut.mean()) + # print("range of feature after ART layer:", x.max(), x.min(), x.mean()) + if self.pre_norm: + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + # print("using post_norm") + x = self.norm1(shortcut + self.drop_path(x)) + x = self.norm2(x + self.drop_path(self.mlp(x))) + # x = shortcut + self.drop_path(x) + # x = x + self.drop_path(self.mlp(x)) + + # print("range of fused feature:", x.max(), x.min()) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + + + +########################################################################## +class Cross_TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +class Cross_TransformerBlock_v2(nn.Module): + r""" ART Transformer Block. + 将两个模态的特征沿着channel方向concat, 一起取window. 之后把取出来的两个模态中同一个window的所有patches合并,送到transformer处理。 + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
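# Sketch of the fusion idea implemented above: the two modalities' window tokens are concatenated
# along the sequence axis and passed through a single self-attention call, so each token attends to
# both modalities before the sequence is split back. nn.MultiheadAttention is only a stand-in here
# for the repo's Attention module (which additionally carries a relative position bias).
import torch
import torch.nn as nn

nW, N, C = 4, 49, 32                        # windows, tokens per window per modality, channels
x_tok = torch.randn(nW, N, C)               # modality 1 tokens
y_tok = torch.randn(nW, N, C)               # modality 2 tokens

attn = nn.MultiheadAttention(embed_dim=C, num_heads=4, batch_first=True)
xy = torch.cat((x_tok, y_tok), dim=1)       # (nW, 2N, C): one joint sequence per window
fused, _ = attn(xy, xy, xy)                 # self-attention across both modalities
x_out, y_out = fused[:, :N], fused[:, N:]   # split back into the two modalities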
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + xy = torch.cat((x, y), -1) + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + xy = F.pad(xy, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + xy = xy.reshape(B, Hd // G, G, Wd // G, G, 2*C).permute(0, 1, 3, 2, 4, 5).contiguous() + xy = xy.reshape(B * Hd * Wd // G ** 2, G ** 2, 2*C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + xy = xy.reshape(B, Gh, I, Gw, I, 2*C).permute(0, 2, 4, 1, 3, 5).contiguous() + xy = xy.reshape(B * I * I, Gh * Gw, 2*C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ x = xy[:, :, :C] + y = xy[:, :, C:] + xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DataConsistency.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DataConsistency.py new file mode 100644 index 0000000000000000000000000000000000000000..5f460e9acabcb33e1a3f812eb9acaeb337da446c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DataConsistency.py @@ -0,0 +1,48 @@ +""" +Created: DataConsistency @ Xiyang Cai, 2023/09/09 + +Data consistency layer for k-space signal. 
+ +Ref: DataConsistency in DuDoRNet (https://github.com/bbbbbbzhou/DuDoRNet) + +""" + +import torch +from torch import nn +from einops import repeat + + +def data_consistency(k, k0, mask): + """ + k - input in k-space + k0 - initially sampled elements in k-space + mask - corresponding nonzero location + """ + out = (1 - mask) * k + mask * k0 + return out + + +class DataConsistency(nn.Module): + """ + Create data consistency operator + """ + + def __init__(self): + super(DataConsistency, self).__init__() + + def forward(self, k, k0, mask): + """ + k - input in frequency domain, of shape (n, nx, ny, 2) + k0 - initially sampled elements in k-space + mask - corresponding nonzero location (n, 1, len, 1) + """ + + if k.dim() != 4: # input is 2D + raise ValueError("error in data consistency layer!") + + # mask = repeat(mask.squeeze(1, 3), 'b x -> b x y c', y=k.shape[1], c=2) + mask = torch.tile(mask, (1, mask.shape[2], 1, k.shape[-1])) ### [n, 320, 320, 2] + # print("k and k0 shape:", k.shape, k0.shape) + out = data_consistency(k, k0, mask) + + return out, mask diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DuDo_mUnet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DuDo_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a56069a27f0cdd392482b41dc4e3b79252295487 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DuDo_mUnet.py @@ -0,0 +1,359 @@ +""" +Xiaohan Xing, 2023/10/25 +传入两个模态的kspace data, 经过kspace network进行重建。 +然后把重建之后的kspace data变换回图像,然后送到image domain的网络进行图像重建。 +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + +from .Kspace_mUnet import kspace_mmUnet + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = kspace_ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = kspace_ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = kspace_ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = kspace_ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. 
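# Numerical check of the comment above (convolution theorem): point-wise multiplication of two
# 2-D spectra corresponds to circular convolution of the underlying images, i.e. a kernel as
# large as the image itself. Double precision keeps the comparison tight.
import torch

H, W = 32, 32
img = torch.randn(H, W, dtype=torch.float64)
kernel = torch.randn(H, W, dtype=torch.float64)

prod = torch.fft.ifft2(torch.fft.fft2(img) * torch.fft.fft2(kernel)).real   # frequency-domain product

# explicit circular convolution at one output position as a spot check
y, x = 5, 7
ref = sum(img[i, j] * kernel[(y - i) % H, (x - j) % W] for i in range(H) for j in range(W))
assert torch.allclose(prod[y, x], ref)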
+ spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class DuDo_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 3, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + # self.kspace_net = mmConvKSpace() + self.kspace_net = kspace_mmUnet(args) + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
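# The "dual" branch of the forward pass below is a k-space -> image cascade: the k-space network
# interpolates missing lines (with data consistency), the result is transformed back to the image
# domain, and the magnitude image feeds the image-domain U-Net. Below is a simplified stand-in for
# the repo's ifft2c/complex_abs pair; the exact shift/normalisation convention of ifft2c in the
# dataloader may differ.
import torch

def kspace_to_magnitude(k_2ch: torch.Tensor) -> torch.Tensor:
    """k_2ch: (N, 2, H, W) real/imaginary k-space -> (N, 1, H, W) magnitude image."""
    k = torch.view_as_complex(k_2ch.permute(0, 2, 3, 1).contiguous())       # (N, H, W) complex
    img = torch.fft.fftshift(torch.fft.ifft2(torch.fft.ifftshift(k, dim=(-2, -1))), dim=(-2, -1))
    return img.abs().unsqueeze(1)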
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + stack = [] + feature_stack = [] + # output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ output = torch.cat((image, LR_image, aux_image), 1) #.detach() + + # print("LR image range:", LR_image.max(), LR_image.min()) + # print("kspace_recon image range:", image.max(), image.min()) + # print("aux_image range:", aux_image.max(), aux_image.min()) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output, image, recon_kspace, masked_kspace, data_stats + + + +class kspace_ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
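# Why the up-sampling loop above pads on the right/bottom: with an odd spatial size the pooling
# floors the resolution, so the stride-2 up-sampling comes back one pixel short of the skip
# feature. F.interpolate stands in for the learned transposed convolution in this sketch.
import torch
import torch.nn.functional as F

skip = torch.randn(1, 8, 15, 15)                      # encoder feature at an odd resolution
down = F.avg_pool2d(skip, kernel_size=2, stride=2)    # (1, 8, 7, 7)
up = F.interpolate(down, scale_factor=2)              # (1, 8, 14, 14): one pixel short
pad = [0, int(up.shape[-1] != skip.shape[-1]), 0, int(up.shape[-2] != skip.shape[-2])]
up = F.pad(up, pad, "reflect")                        # pad right/bottom -> (1, 8, 15, 15)
merged = torch.cat([up, skip], dim=1)                 # (1, 16, 15, 15)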
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return DuDo_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6df67b63140ca51bd7b8664151ab4b1b7a6305d5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py @@ -0,0 +1,369 @@ +""" +Xiaohan Xing, 2023/11/08 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + ARTfusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + 
+ # reflect pad on the right/bottom if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + The overall framework is a multi-modal U-Net: features are extracted at each level for the two modalities separately and then fused at every level by ART blocks. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`.
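# How the fuse_layers above are sized, using the constructor defaults: at level l each modality
# carries chans * 2**l channels, the two are concatenated (hence dim = chans * 2**(l+1)), and the
# blocks inside a stage alternate between dense (ds_flag=0, "DAB") and sparse (ds_flag=1, "SAB")
# attention.
chans, num_blocks, heads = 32, [2, 4, 4, 6], [1, 2, 4, 8]
window_size, interval = [10, 10, 10, 10], [24, 12, 6, 3]
for l in range(4):
    blocks = ["DAB" if i % 2 == 0 else "SAB" for i in range(num_blocks[l])]
    print(f"level {l}: dim={chans * 2 ** (l + 1)}, heads={heads[l]}, "
          f"window={window_size[l]}, interval={interval[l]}, blocks={blocks}")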
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + output1, output2 = aux_image.detach(), LR_image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
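# The core of get_relation_matrix above is a cosine-similarity matrix between pooled spatial
# positions; the same quantity (after the pooling step) can be written as normalise-then-matmul:
import torch
import torch.nn.functional as F

bs, c = 2, 64
feat = torch.randn(bs, 25, c)                 # (batch, 5*5 pooled positions, channels)
unit = F.normalize(feat, dim=-1)
rel = torch.bmm(unit, unit.transpose(1, 2))   # (bs, 25, 25) cosine similarities
assert torch.allclose(rel.diagonal(dim1=1, dim2=2), torch.ones(bs, 25), atol=1e-5)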
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c2b86b6e2031b1af07ab7cb58efc182b7a2673 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py @@ -0,0 +1,338 @@ +""" +Xiaohan Xing, 2023/11/10 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + multi layer concat fusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = 
torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_CATfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t2_gt: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + t2_image = t2_gt * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + # # print("aux_image range:", aux_image.max(), aux_image.min()) + # print("LR_image range:", LR_image.max(), LR_image.min()) + # print("kspace recon_img range:", image.max(), image.min()) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + t2_gt_image = self.normalize(t2_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + t2_gt_image = torch.clamp(t2_gt_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + # output1, output2 = aux_image.detach(), LR_image.detach() + output1, output2 = aux_image.detach(), image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image, recon_kspace, masked_kspace, data_stats + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_CATfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Kspace_ConvNet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Kspace_ConvNet.py new file mode 100644 index 0000000000000000000000000000000000000000..918cbe598a62653394c8c4eaf99cd10a45c064ca --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Kspace_ConvNet.py @@ -0,0 +1,140 @@ +""" +2023/10/05 +Xiaohan Xing +Build a simple network with several convolution layers. 
+Input: concat (undersampled kspace of FS-PD, fully-sampled kspace of PD modality) +Output: interpolated kspace of FS-PD +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, args, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + # self.conv_blocks = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # self.conv_blocks.append(ConvBlock(self.chans, self.chans * 2, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans * 2, self.chans, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans, self.out_chans, drop_prob)) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + # print("output of the three conv blocks:", output1.shape, output2.shape, output3.shape) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("conv1_output range:", output1.max(), output1.min()) + # print("conv2_output range:", output2.max(), output2.min()) + # print("conv3_output range:", output3.max(), output3.min()) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. + spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + # print("output before DC layer:", output.shape) + # print("output range before DC layer:", output.max(), output.min()) + # output = output + kspace ## residual connection + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + # mask_output = output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
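# Shape-level sketch of the multi-scale k-space branch above: three parallel convolutions with
# 3/5/7 kernels are concatenated and re-weighted by a single sigmoid map before the final
# projection. Plain Conv2d layers stand in for kspace_ConvBlock; 240x240 follows the file headers.
import torch
import torch.nn as nn

chans = 32
k_in = torch.randn(1, 4, 240, 240)                      # target + reference k-space, 2 channels each
branches = [nn.Conv2d(4, chans, ks, padding=ks // 2) for ks in (3, 5, 7)]
gate = nn.Sequential(nn.Conv2d(3 * chans, chans, 1), nn.ReLU(),
                     nn.Conv2d(chans, 1, 1), nn.Sigmoid())
feat = torch.cat([b(k_in) for b in branches], dim=1)    # (1, 96, 240, 240)
weighted = gate(feat) * feat                            # one (1, 1, H, W) weight map, broadcast
out = nn.Conv2d(3 * chans, 2, 3, padding=1)(weighted)   # back to a 2-channel k-space estimate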
+ """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +def build_model(args): + return mmConvKSpace(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Kspace_mUnet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Kspace_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2db9a357798be4abd4cfbd44e9207555770b0456 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Kspace_mUnet.py @@ -0,0 +1,202 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # print("output:", output.shape) + # print("kspace:", kspace.shape) + # print("mask:", mask.shape) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Kspace_mUnet_new.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Kspace_mUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ccfd28dc198ffeff2259a5fbed45dabdc76819 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Kspace_mUnet_new.py @@ -0,0 +1,200 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # print("using Unet without Relu layers") + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + # output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # # print("output:", output.shape) + # # print("kspace:", kspace.shape) + # # print("mask:", mask.shape) + # mask_output = mask * output + + return output # .permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/MINet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeeb063b12be1dc594e8e076e6a59e454fcfa0b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/MINet.py @@ -0,0 +1,356 @@ +""" +Compare with the method "Multi-contrast MRI Super-Resolution via a Multi-stage Integration Network". +The code is from https://github.com/chunmeifeng/MINet/edit/main/fastmri/models/MINet.py. 
+""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F + + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): 
+ def __init__(self,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 1 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + # print("modules_tail:", modules_tail) + # self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Sigmoid()) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + + + +class MINet(nn.Module): + def __init__(self, args, n_resgroups=3, n_resblocks=3, n_feats=64): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + # print("x1:", x1.shape, "x2:", x2.shape) + + ### 源代码中是super-resolution任务,所以self.tail对图像进行unsampling. 
我们做reconstruction把这步去掉 + x2 = self.tail(x2) + + + resT1 = x1 + resT2 = x2 + + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + # print("t1s:", [item.shape for item in t1s]) + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + # print("resT1 range:", resT1.max(), resT1.min()) + # print("resT2 range:", resT2.max(), resT2.min()) + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1, x2 #x1=pd x2=pdfs + + +def build_model(args): + return MINet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/MINet_common.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/MINet_common.py new file mode 100644 index 0000000000000000000000000000000000000000..631f3036db4edf8eab00d70c66d2058a3b65cd59 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/MINet_common.py @@ -0,0 +1,90 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + # print("upsampler:", m) + + super(Upsampler, self).__init__(*m) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/SANet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/SANet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfac3590e248c9532b002885c6c0b7a55ab1400a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/SANet.py @@ -0,0 +1,392 @@ +""" +This script contains the codes of "Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution (TNNLS 2023)" +https://github.com/chunmeifeng/SANet/blob/main/fastmri/models/SANet.py +""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import os + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = 
x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,scale,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups #10 + self.n_resblocks = n_resblocks #20 + self.n_feats = n_feats #64 + kernel_size = 3 + reduction = 16 #16 + # scale = args.scale[0] + + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler 
to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' + .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class Seatt(nn.Module): + def __init__(self, in_c): + super(Seatt, self).__init__() + self.reduce = nn.Conv2d(in_c * 2, 32, 1) + self.ff_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.bf_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.rgbd_pred_layer = Pred_Layer(32 * 2) + self.convq = nn.Conv2d(32, 32, 3, 1, 1) + + def forward(self, rgb_feat, dep_feat, pred): + feat = torch.cat((rgb_feat, dep_feat), 1) + + # vis_attention_map_rgb_feat(rgb_feat[0][0]) + # vis_attention_map_dep_feat(dep_feat[0][0]) + # vis_attention_map_feat(feat[0][0]) + + + feat = self.reduce(feat) + [_, _, H, W] = feat.size() + pred = torch.sigmoid(pred) + + ni_pred = 1 - pred + + ff_feat = self.ff_conv(feat * pred) + bf_feat = self.bf_conv(feat * (1 - pred)) + new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) + + return new_pred + +class SANet(nn.Module): + def __init__(self, args, scale=1, n_resgroups=3, n_resblocks=3, n_feats=64): + super(SANet, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + + cs = [64,64,64,64] + self.Seatts = nn.ModuleList([Seatt(c) for c in cs]) + + self.net1 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + print("nlayer:",nlayer) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=3, padding=1) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2, Seatt, map_conv in zip(self.net1.body._modules.items(), self.net2.body._modules.items(), self.Seatts, self.map_convs): + + name1, midlayer1 = 
m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + pred = map_conv(resT1) + res = Seatt(resT1,resT2,pred) + + resT2 = res+resT2 + + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + + return x1,x2 + + +def build_model(args): + return SANet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/SwinFuse_layer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/SwinFuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3dca8ba793d92961bc3f1d987e211fe542b303 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/SwinFuse_layer.py @@ -0,0 +1,1310 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # print("x shape:", x.shape) + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + # print("num heads:", self.num_heads) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + 
flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask 
is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + # print("x_windows:", x_windows.shape) + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + + # # 不使用残差连接 + # x = self.residual_group_A(x, x_size) + # y = self.residual_group_B(y, x_size) + + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion_layer(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
+ Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Fusion_num_heads=[6, 6], + window_size=7, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, img_range=1., resi_connection='1conv', + **kwargs): + super(SwinFusion_layer, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + 
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + + ### fusion layer. + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + self.norm_Fusion_B = norm_layer(self.num_features) + + # print("fusion layers:", len(self.layers_Fusion)) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + + def forward_features_Fusion(self, x, y): + + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + # x = torch.cat([x, y], 1) + # ## Downsample the feature in the channel dimension + # x = self.lrelu(self.conv_after_body_Fusion(x)) + # print("x range after fusion:", x.max(), x.min()) + # print("y 
range after fusion:", y.max(), y.min()) + # print("x:", x.shape) + + return x, y + + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = self.check_image_size(y) + + print("modality1 input:", x.max(), x.min()) + print("modality2 input:", y.max(), y.min()) + + # multi-modal feature fusion. + x, y = self.forward_features_Fusion(x, y) ### 返回两个模态融合之后的features. + print("modality1 output:", x.max(), x.min()) + print("modality2 output:", y.max(), y.min()) + + return x[:, :, :H, :W], y[:, :, :H, :W] + + + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + + +if __name__ == '__main__': + width, height = 120, 120 + swinfuse = SwinFusion_layer(img_size=(width, height), window_size=8, depths=[6, 6, 6], embed_dim=60, num_heads=[6, 6, 6]) + # print("SwinFusion layer:", swinfuse) + + A = torch.randn((4, 60, width, height)) + B = torch.randn((4, 60, width, height)) + + x = swinfuse(A, B) + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/SwinFusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/SwinFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3e72c2a3e8859423f3544043e8b9feaec002d0e2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/SwinFusion.py @@ -0,0 +1,1468 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # print("qkv after reshape:", qkv.shape) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. 
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + # 不使用残差连接 + # print("input x shape:", x.shape) + x = self.residual_group_A(x, x_size) + y = self.residual_group_B(y, x_size) + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
+
+    Args:
+        img_size (int | tuple(int)): Input image size. Default: 64
+        patch_size (int | tuple(int)): Patch size. Default: 1
+        in_chans (int): Number of input image channels. Default: 3
+        embed_dim (int): Patch embedding dimension. Default: 96
+        depths (tuple(int)): Depth of each Swin Transformer layer.
+        num_heads (tuple(int)): Number of attention heads in different layers.
+        window_size (int): Window size. Default: 7
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+        drop_rate (float): Dropout rate. Default: 0
+        attn_drop_rate (float): Attention dropout rate. Default: 0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.1
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
+        img_range: Image range. 1. or 255.
+        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+    """
+
+    def __init__(self, img_size=64, patch_size=1, in_chans=1,
+                 embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Re_depths=[4],
+                 Ex_num_heads=[6], Fusion_num_heads=[6, 6], Re_num_heads=[6],
+                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
+                 **kwargs):
+        super(SwinFusion, self).__init__()
+        num_out_ch = in_chans
+        num_feat = 64
+        self.img_range = img_range
+        embed_dim_temp = int(embed_dim / 2)
+        print('in_chans: ', in_chans)
+        if in_chans == 3 or in_chans == 6:
+            rgb_mean = (0.4488, 0.4371, 0.4040)
+            rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040)
+            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+            self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1)
+        else:
+            self.mean = torch.zeros(1, 1, 1, 1)
+        self.upscale = upscale
+        self.upsampler = upsampler
+        self.window_size = window_size
+
+        #####################################################################################################
+        ################################### 1, shallow feature extraction ###################################
+        #### modified the shallow feature extraction network to two 3x3 convolutions per modality ####
+        self.conv_first1_A = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1)
+        self.conv_first1_B = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1)
+        self.conv_first2_A = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1)
+        self.conv_first2_B = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        #####################################################################################################
+        ################################### 2, deep feature extraction ######################################
+        self.Ex_num_layers = len(Ex_depths)
+        self.Fusion_num_layers = len(Fusion_depths)
+        self.Re_num_layers = len(Re_depths)
+
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = 
patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + dpr_Re = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Re_depths))] # stochastic depth decay rule + # build Residual Swin Transformer blocks (RSTB) + self.layers_Ex_A = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_A.append(layer) + self.norm_Ex_A = norm_layer(self.num_features) + + self.layers_Ex_B = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_B.append(layer) + self.norm_Ex_B = norm_layer(self.num_features) + + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + 
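+        # A separate post-fusion LayerNorm is kept for each modality branch (A above, B below).
+        # For reference, the stochastic-depth schedule above slices one linspace per stage, e.g. with
+        # Fusion_depths=[2, 2] and drop_path_rate=0.1:
+        #   dpr_Fusion ≈ [0.0, 0.033, 0.067, 0.1]
+        #   fusion block 0 receives dpr_Fusion[0:2], fusion block 1 receives dpr_Fusion[2:4]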
self.norm_Fusion_B = norm_layer(self.num_features) + + self.layers_Re = nn.ModuleList() + for i_layer in range(self.Re_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Re_depths[i_layer], + num_heads=Re_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Re[sum(Re_depths[:i_layer]):sum(Re_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Re.append(layer) + self.norm_Re = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
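+            # Note: forward() always reconstructs through forward_features_Re, which uses conv_last1/2/3
+            # from the final `else` branch; the 'pixelshuffle' / 'pixelshuffledirect' / 'nearest+conv'
+            # heads defined in these branches are never called in the forward pass (e.g. build_model below
+            # passes upsampler='convs', which falls through to the plain conv reconstruction head).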
+            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+                                                      nn.LeakyReLU(inplace=True))
+            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        else:
+            # for image denoising and JPEG compression artifact reduction
+            self.conv_last1 = nn.Conv2d(embed_dim, embed_dim_temp, 3, 1, 1)
+            self.conv_last2 = nn.Conv2d(embed_dim_temp, int(embed_dim_temp/2), 3, 1, 1)
+            self.conv_last3 = nn.Conv2d(int(embed_dim_temp/2), num_out_ch, 3, 1, 1)
+
+        self.tanh = nn.Tanh()
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'absolute_pos_embed'}
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        return {'relative_position_bias_table'}
+
+    def check_image_size(self, x):
+        _, _, h, w = x.size()
+        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+        return x
+
+    def forward_features_Ex_A(self, x):
+        x = self.lrelu(self.conv_first1_A(x))
+        x = self.lrelu(self.conv_first2_A(x))
+        x_size = (x.shape[2], x.shape[3])
+        x = self.patch_embed(x)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers_Ex_A:
+            x = layer(x, x_size)
+
+        x = self.norm_Ex_A(x)  # B L C
+        x = self.patch_unembed(x, x_size)
+
+        return x
+
+    def forward_features_Ex_B(self, x):
+        x = self.lrelu(self.conv_first1_B(x))
+        x = self.lrelu(self.conv_first2_B(x))
+        x_size = (x.shape[2], x.shape[3])
+        x = self.patch_embed(x)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers_Ex_B:
+            x = layer(x, x_size)
+
+        x = self.norm_Ex_B(x)  # B L C
+        x = self.patch_unembed(x, x_size)
+        return x
+
+    def forward_features_Fusion(self, x, y):
+        x_size = (x.shape[2], x.shape[3])
+        x = self.patch_embed(x)
+        y = self.patch_embed(y)
+        # print("x token:", x.shape, "y token:", y.shape)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+            y = y + self.absolute_pos_embed
+        x = self.pos_drop(x)
+        y = self.pos_drop(y)
+
+        for layer in self.layers_Fusion:
+            x, y = layer(x, y, x_size)
+            # y = layer(y, x, x_size)
+
+
+        x = self.norm_Fusion_A(x)  # B L C
+        x = self.patch_unembed(x, x_size)
+
+        y = self.norm_Fusion_B(y)  # B L C
+        y = self.patch_unembed(y, x_size)
+        x = torch.cat([x, y], 1)
+        ## Downsample the feature in the channel dimension
+        x = self.lrelu(self.conv_after_body_Fusion(x))
+
+        return x
+
+    def forward_features_Re(self, x):
+        x_size = (x.shape[2], x.shape[3])
+        x = self.patch_embed(x)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers_Re:
+            x = layer(x, x_size)
+
+        x = self.norm_Re(x)  # B L C
+        x = self.patch_unembed(x, x_size)
+        # Convolution
+        x = self.lrelu(self.conv_last1(x))
+        x = self.lrelu(self.conv_last2(x))
+        x = self.conv_last3(x)
+        return x
+
+    def forward(self, A, B):
+        # print("Initializing the model")
+        x = A
+        y = B
+        H, W = x.shape[2:]
+        x = self.check_image_size(x)
+        y = 
self.check_image_size(y) + + # self.mean_A = self.mean.type_as(x) + # self.mean_B = self.mean.type_as(y) + # self.mean = (self.mean_A + self.mean_B) / 2 + + # x = (x - self.mean_A) * self.img_range + # y = (y - self.mean_B) * self.img_range + + # Feedforward + x = self.forward_features_Ex_A(x) + y = self.forward_features_Ex_B(y) + # print("x before fusion:", x.shape) + # print("y before fusion:", y.shape) + x = self.forward_features_Fusion(x, y) + x = self.forward_features_Re(x) + x = self.tanh(x) + + # x = x / self.img_range + self.mean + # print("H:", H, "upscale:", self.upscale) + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='convs') + + +if __name__ == '__main__': + model = SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='convs') + # print(model) + + A = torch.randn((1, 1, 240, 240)) + B = torch.randn((1, 1, 240, 240)) + + x = model(A, B) + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/TransFuse.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/TransFuse.py new file mode 100644 index 0000000000000000000000000000000000000000..56ca3d2b24f382603c876e4f6909363a9794ffa3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/TransFuse.py @@ -0,0 +1,155 @@ +import math +from collections import deque + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torchvision import models + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. 
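+    (Despite the name, no causal mask is applied in forward(); attention is computed over all tokens.)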
+ """ + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(n_embd, n_embd) + self.query = nn.Linear(n_embd, n_embd) + self.value = nn.Linear(n_embd, n_embd) + # regularization + self.attn_drop = nn.Dropout(attn_pdrop) + self.resid_drop = nn.Dropout(resid_pdrop) + # output projection + self.proj = nn.Linear(n_embd, n_embd) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop): + super().__init__() + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.mlp = nn.Sequential( + nn.Linear(n_embd, block_exp * n_embd), + nn.ReLU(True), # changed from GELU + nn.Linear(block_exp * n_embd, n_embd), + nn.Dropout(resid_pdrop), + ) + + def forward(self, x): + B, T, C = x.size() + + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + + return x + + +class TransFuse_layer(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, n_embd, n_head, block_exp, n_layer, + num_anchors, seq_len=1, + embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1): + super().__init__() + self.n_embd = n_embd + self.seq_len = seq_len + self.vert_anchors = num_anchors + self.horz_anchors = num_anchors + + # positional embedding parameter (learnable), image + lidar + self.pos_emb = nn.Parameter(torch.zeros(1, 2 * seq_len * self.vert_anchors * self.horz_anchors, n_embd)) + + self.drop = nn.Dropout(embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(n_embd, n_head, + block_exp, attn_pdrop, resid_pdrop) + for layer in range(n_layer)]) + + # decoder head + self.ln_f = nn.LayerNorm(n_embd) + + self.block_size = seq_len + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + + def forward(self, m1, m2): + """ + Args: + m1 (tensor): B*seq_len, C, H, W + m2 (tensor): B*seq_len, C, H, W + """ + + bz = m2.shape[0] // self.seq_len + h, w = m2.shape[2:4] + + # forward the image model for token embeddings + m1 = m1.view(bz, self.seq_len, -1, h, w) + m2 = m2.view(bz, self.seq_len, -1, h, w) + + # pad token embeddings along number of tokens dimension + token_embeddings = torch.cat([m1, m2], 
dim=1).permute(0,1,3,4,2).contiguous()
+        token_embeddings = token_embeddings.view(bz, -1, self.n_embd)  # (B, an * T, C)
+
+        # add (learnable) positional embedding for all tokens
+        x = self.drop(self.pos_emb + token_embeddings)  # (B, an * T, C)
+        x = self.blocks(x)  # (B, an * T, C)
+        x = self.ln_f(x)  # (B, an * T, C)
+        x = x.view(bz, 2 * self.seq_len, self.vert_anchors, self.horz_anchors, self.n_embd)
+        x = x.permute(0,1,4,2,3).contiguous()  # same as token_embeddings
+
+        m1_out = x[:, :self.seq_len, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w)
+        m2_out = x[:, self.seq_len:, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w)
+        print("modality1 output:", m1_out.max(), m1_out.min())
+        print("modality2 output:", m2_out.max(), m2_out.min())
+
+        return m1_out, m2_out
+
+
+if __name__ == "__main__":
+
+    feature1 = torch.randn((4, 512, 20, 20))
+    feature2 = torch.randn((4, 512, 20, 20))
+    model = TransFuse_layer(n_embd=512, n_head=4, block_exp=4, n_layer=8, num_anchors=20, seq_len=1)
+    print("TransFuse_layer:", model)
+    feat1, feat2 = model(feature1, feature2)
+    print(feat1.shape, feat2.shape)
diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Unet_ART.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Unet_ART.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9b0f834270a921ae4eb428c92b3e782fbed19b3
--- /dev/null
+++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/Unet_ART.py
@@ -0,0 +1,243 @@
+"""
+2023/07/06: attach an ART transformer block after the features of every U-Net level to extract long-range dependent features.
+"""
+"""
+Copyright (c) Facebook, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .ARTfuse_layer import TransformerBlock
+
+
+class Unet_ART(nn.Module):
+    """
+    PyTorch implementation of a U-Net model.
+
+    O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
+    for biomedical image segmentation. In International Conference on Medical
+    image computing and computer-assisted intervention, pages 234–241.
+    Springer, 2015.
+    """
+
+    def __init__(
+        self, args,
+        input_dim: int = 1,
+        output_dim: int = 1,
+        chans: int = 32,
+        num_pool_layers: int = 4,
+        drop_prob: float = 0.0,
+    ):
+    # def __init__(self, args):
+        """
+        Args:
+            in_chans: Number of channels in the input to the U-Net model.
+            out_chans: Number of channels in the output to the U-Net model.
+            chans: Number of output channels of the first convolution layer.
+            num_pool_layers: Number of down-sampling and up-sampling layers.
+            drop_prob: Dropout probability.
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + output = conv_layer(output) + + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/__init__.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..334f1821a9b59e692641f70cdc61e7655b1a9c89 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/__init__.py @@ -0,0 +1,114 @@ +from .unet import build_model as UNET +from .mmunet import build_model as MUNET +from .humus_net import build_model as HUMUS +from .MINet import build_model as MINET +from .SANet import build_model as SANET +from .mtrans_net import build_model as MTRANS +from .unimodal_transformer import build_model as TRANS +from .cnn_transformer_new import build_model as CNN_TRANS +from .cnn import build_model as CNN +from .unet_transformer import build_model as UNET_TRANSFORMER +from .swin_transformer import build_model as SWIN_TRANS +from .restormer import build_model as RESTORMER + + +from .cnn_late_fusion import build_model as MULTI_CNN +from .mmunet_late_fusion import build_model as MMUNET_LATE +from .mmunet_early_fusion import build_model as MMUNET_EARLY +from .trans_unet.trans_unet import build_model as TRANSUNET +from .SwinFusion import build_model as SWINMULTI +from .mUnet_transformer import build_model as MUNET_TRANSFORMER +from .munet_multi_transfuse import build_model as MUNET_TRANSFUSE +from .munet_swinfusion import build_model as MUNET_SWINFUSE +from .munet_multi_concat import build_model as MUNET_CONCAT +from .munet_concat_decomp import build_model as MUNET_CONCAT_DECOMP +from .munet_multi_sum import build_model as MUNET_SUM +# from .mUnet_ARTfusion_new import build_model as MUNET_ART_FUSION +from .mUnet_ARTfusion import build_model as MUNET_ART_FUSION +from .mUnet_Restormer_fusion import build_model as MUNET_RESTORMER_FUSION +from .mUnet_ARTfusion_SeqConcat import build_model as MUNET_ART_FUSION_SEQ + +from .ARTUnet import build_model as ART +from .Unet_ART import build_model as UNET_ART +from .unet_restormer import build_model as UNET_RESTORMER +from .mmunet_ART import build_model as MMUNET_ART +from .mmunet_restormer import build_model as MMUNET_RESTORMER +from .mmunet_restormer_v2 import build_model as MMUNET_RESTORMER_V2 +from .mmunet_restormer_3blocks import build_model as MMUNET_RESTORMER_SMALL + +from .mARTUnet import build_model as ART_MULTI_INPUT + +from .ART_Restormer import build_model as ART_RESTORMER +from .ART_Restormer_v2 import build_model as ART_RESTORMER_V2 +from .swinIR import build_model as SWINIR + +from .Kspace_ConvNet import build_model as KSPACE_CONVNET +from .Kspace_mUnet_new import build_model as KSPACE_MUNET +from .kspace_mUnet_concat import build_model as KSPACE_MUNET_CONCAT +from .kspace_mUnet_AttnFusion import build_model as KSPACE_MUNET_ATTNFUSE + +from .DuDo_mUnet import build_model as DUDO_MUNET +from .DuDo_mUnet_ARTfusion import build_model as DUDO_MUNET_ARTFUSION +from .DuDo_mUnet_CatFusion import 
build_model as DUDO_MUNET_CONCAT + + +model_factory = { + 'unet_single': UNET, + 'humus_single': HUMUS, + 'transformer_single': TRANS, + 'cnn_transformer': CNN_TRANS, + 'cnn_single': CNN, + 'swin_trans_single': SWIN_TRANS, + 'trans_unet': TRANSUNET, + 'unet_transformer': UNET_TRANSFORMER, + 'restormer': RESTORMER, + 'unet_art': UNET_ART, + 'unet_restormer': UNET_RESTORMER, + 'art': ART, + 'art_restormer': ART_RESTORMER, + 'art_restormer_v2': ART_RESTORMER_V2, + 'swinIR': SWINIR, + + + 'munet_transformer': MUNET_TRANSFORMER, + 'munet_transfuse': MUNET_TRANSFUSE, + 'cnn_late_multi': MULTI_CNN, + 'unet_multi': MUNET, + 'unet_late_multi':MMUNET_LATE, + 'unet_early_multi':MMUNET_EARLY, + 'munet_ARTfusion': MUNET_ART_FUSION, + 'munet_restormer_fusion': MUNET_RESTORMER_FUSION, + + + 'minet_multi': MINET, + 'sanet_multi': SANET, + 'mtrans_multi': MTRANS, + 'swin_fusion': SWINMULTI, + 'munet_swinfuse': MUNET_SWINFUSE, + 'munet_concat': MUNET_CONCAT, + 'munet_concat_decomp': MUNET_CONCAT_DECOMP, + 'munet_sum': MUNET_SUM, + 'mmunet_art': MMUNET_ART, + 'mmunet_restormer': MMUNET_RESTORMER, + 'mmunet_restormer_small': MMUNET_RESTORMER_SMALL, + 'mmunet_restormer_v2': MMUNET_RESTORMER_V2, + + 'art_multi_input': ART_MULTI_INPUT, + + 'munet_ARTfusion_SeqConcat': MUNET_ART_FUSION_SEQ, + + 'kspace_ConvNet': KSPACE_CONVNET, + 'kspace_munet': KSPACE_MUNET, + 'kspace_munet_concat': KSPACE_MUNET_CONCAT, + 'kspace_munet_AttnFusion': KSPACE_MUNET_ATTNFUSE, + + 'dudo_munet': DUDO_MUNET, + 'dudo_munet_ARTfusion': DUDO_MUNET_ARTFUSION, + 'dudo_munet_concat': DUDO_MUNET_CONCAT, +} + + +def build_model_from_name(args): + assert args.model_name in model_factory.keys(), 'unknown model name' + + return model_factory[args.model_name](args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c130178e7cdabaaef8686da0cec341265f3c43d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn.py @@ -0,0 +1,246 @@ +""" +把our method中的单模态CNN_Transformer中的transformer换成几个conv_block用来下采样提取特征, 看是否能够达到和Unet相似的效果。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Reconstructor(nn.Module): + def __init__(self): + super(CNN_Reconstructor, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + nlatent = self.encoder.output_dim + self.decoder = Decoder(n_upsample=4, n_res=1, dim=nlatent, output_dim=1, pad_type='zero') + + + def forward(self, m): + #4,1,240,240 + feature_output = self.encoder.model(m) # [4, 256, 60, 60] + # print("output of the encoder:", feature_output.shape) + + output = self.decoder(feature_output) + + return output + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + 
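+    # Position-wise MLP of the pre-norm transformer block:
+    # Linear(dim -> hidden_dim) -> GELU -> Dropout -> Linear(hidden_dim -> dim) -> Dropout.
+    # Note: the Transformer/Attention classes in this file are kept from the CNN_Transformer variant;
+    # CNN_Reconstructor above is a plain conv encoder-decoder and does not call them.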
def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# 
Basic Blocks
+##################################################################################
+class ResBlock(nn.Module):
+    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
+        super(ResBlock, self).__init__()
+
+        model = []
+        model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
+        model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        residual = x
+        out = self.model(x)
+        out += residual
+        return out
+
+class Conv2dBlock(nn.Module):
+    def __init__(self, input_dim ,output_dim, kernel_size, stride,
+                 padding=0, norm='none', activation='relu', pad_type='zero'):
+        super(Conv2dBlock, self).__init__()
+        self.use_bias = True
+        # initialize padding
+        if pad_type == 'reflect':
+            self.pad = nn.ReflectionPad2d(padding)
+        elif pad_type == 'replicate':
+            self.pad = nn.ReplicationPad2d(padding)
+        elif pad_type == 'zero':
+            self.pad = nn.ZeroPad2d(padding)
+        else:
+            assert 0, "Unsupported padding type: {}".format(pad_type)
+
+        # initialize normalization
+        norm_dim = output_dim
+        if norm == 'bn':
+            self.norm = nn.BatchNorm2d(norm_dim)
+        elif norm == 'in':
+            self.norm = nn.InstanceNorm2d(norm_dim)
+        elif norm == 'none' or norm == 'sn':
+            self.norm = None
+        else:
+            assert 0, "Unsupported normalization: {}".format(norm)
+
+        # initialize activation
+        if activation == 'relu':
+            self.activation = nn.ReLU(inplace=True)
+        elif activation == 'lrelu':
+            self.activation = nn.LeakyReLU(0.2, inplace=True)
+        elif activation == 'prelu':
+            self.activation = nn.PReLU()
+        elif activation == 'selu':
+            self.activation = nn.SELU(inplace=True)
+        elif activation == 'tanh':
+            self.activation = nn.Tanh()
+        elif activation == 'none':
+            self.activation = None
+        else:
+            assert 0, "Unsupported activation: {}".format(activation)
+
+        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
+
+    def forward(self, x):
+        x = self.conv(self.pad(x))
+        if self.norm:
+            x = self.norm(x)
+        if self.activation:
+            x = self.activation(x)
+        return x
+
+
+def build_model(args):
+    return CNN_Reconstructor()
diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn_late_fusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn_late_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..58dacad6910b6e896d00a0d36c9ea60e8d77217f
--- /dev/null
+++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn_late_fusion.py
@@ -0,0 +1,196 @@
+"""
+Based on Lequan's code, with the transformer-based feature fusion removed.
+CNN features from the two modalities are simply concatenated as the multi-modal representation, and a plain decoder reconstructs the image.
+T1 serves as the guidance modality, T2 as the target modality.
+"""
+# coding: utf-8
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class mmUnet(nn.Module):
+    """
+    PyTorch implementation of a U-Net model.
+
+    O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
+    for biomedical image segmentation. In International Conference on Medical
+    image computing and computer-assisted intervention, pages 234–241.
+    Springer, 2015.
+    """
+
+    def __init__(
+        self, args,
+        input_dim: int = 1,
+        output_dim: int = 1,
+        chans: int = 32,
+        num_pool_layers: int = 4,
+        drop_prob: float = 0.0,
+    ):
+    # def __init__(self, args):
+        """
+        Args:
+            in_chans: Number of channels in the input to the U-Net model.
+            out_chans: Number of channels in the output to the U-Net model.
+ chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers1 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + self.down_sample_layers2 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers1.append(ConvBlock(ch, ch * 2, drop_prob)) + self.down_sample_layers2.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + + # self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # ch = chans + # for _ in range(num_pool_layers - 1): + # self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + # ch *= 2 + # self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv = nn.Sequential(nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), nn.Tanh()) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = torch.cat([image, aux_image], 1) + # for layer in self.down_sample_layers: + # output = layer(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + + # apply down-sampling layers + output = image + for layer in self.down_sample_layers1: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + output_m1 = output + + output = aux_image + for layer in self.down_sample_layers2: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + output_m2 = output + + # print("two modality outputs:", output_m1.shape, output.shape) + output = torch.cat([output_m1, output_m2], 1) + + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv in self.up_transpose_conv: + output = transpose_conv(output) + + output = self.up_conv(output) + # print("model output:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..261d4d214e35c3a9be5d28bc426ce7280124723c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn_transformer.py @@ -0,0 +1,289 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 
送到transformer网络中进行MRI重建。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=2, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=2, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 60 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 4 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
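+        # Shape walk-through with the defaults above (assuming a 240x240 single-channel input):
+        #   encoder (n_downsample=2, dim 64 -> 256): m_out is [B, 256, 60, 60]
+        #   to_patch_embedding (Conv2d, kernel=stride=4): m_embed is [B, 15*15=225, 1024]
+        # Each 1024-d token is then L2-normalized, a cls token is prepended, a learned positional
+        # embedding is added, and the depth-2 transformer refines the tokens before they are folded
+        # back to [B, 1024, 15, 15] and upsampled twice for the decoder.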
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + feature_output = self.upsampling1(feature_output) # [4,512,30,30] + feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn_transformer_new.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn_transformer_new.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b2dea1e7cc8d456a684b6c839f052383999e84 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/cnn_transformer_new.py @@ -0,0 +1,290 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 送到transformer网络中进行MRI重建。 +2023/05/23: 之前是从尺寸为60*60的特征图上切patch, 现在可以尝试直接用conv_blocks提取15*15的特征图,然后按照patch_size = 1切成225个patches。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=4, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 1 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
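+        # Shape walk-through for this variant (defaults above, assuming a 240x240 single-channel input):
+        #   encoder (n_downsample=4, dim 64 -> 1024): m_out is [B, 1024, 15, 15] (not 60x60 as in the older file)
+        #   to_patch_embedding (Conv2d, kernel=stride=1): m_embed is [B, 15*15=225, 1024], one token per position
+        # The tokens then follow the same normalize / cls-token / positional-embedding / transformer pipeline,
+        # are folded back to [B, 1024, 15, 15], and go straight to the 4-stage decoder (the two transpose-conv
+        # upsampling blocks are commented out below).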
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + # feature_output = self.upsampling1(feature_output) # [4,512,30,30] + # feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/humus_net.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/humus_net.py new file mode 100644 index 0000000000000000000000000000000000000000..d23701634f849b414866d71b3c23c6d07f3d6437 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/humus_net.py @@ -0,0 +1,1070 @@ +""" +HUMUS-Block +Hybrid Transformer-convolutional Multi-scale denoiser. + +We use parts of the code from +SwinIR (https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py) +Swin-Unet (https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.py) + +HUMUS-Net中是对kspace data操作, 多个unrolling cascade, 每个cascade中都利用HUMUS-block对图像进行denoising. +因为我们的方法输入是Under-sampled images, 所以不需要进行unrolling, 直接使用一个HUMUS-block进行MRI重建。 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # print("x shape:", x.shape) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. 
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + 
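+        # Note: window_partition and window_reverse (defined above) are exact
+        # inverses. shifted_x of shape (B, H, W, C) is split into
+        # (H//window_size)*(W//window_size) non-overlapping windows, self-attention
+        # runs independently inside each window (with the SW-MSA mask when
+        # shift_size > 0), and window_reverse stitches the windows back into a
+        # (B, H, W, C) map before the cyclic shift is undone below.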
# reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + +class PatchExpandSkip(nn.Module): + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) + self.channel_mix = nn.Linear(dim, dim // 2, bias=False) + self.norm = norm_layer(dim // 2) + + def forward(self, x, skip): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x = torch.cat([x, skip], dim=-1) + x = self.channel_mix(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, torch.tensor(x_size)) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ block_type: + D: downsampling block, + U: upsampling block + B: bottleneck block + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', block_type='B'): + super(RSTB, self).__init__() + self.dim = dim + conv_dim = dim // (patch_size ** 2) + divide_out_ch = 1 + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(conv_dim, conv_dim // divide_out_ch, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(conv_dim, conv_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // divide_out_ch, 3, 1, 1)) + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.block_type = block_type + if block_type == 'B': + self.reshape = torch.nn.Identity() + elif block_type == 'D': + self.reshape = PatchMerging(input_resolution, dim) + elif block_type == 'U': + self.reshape = PatchExpandSkip([res // 2 for res in input_resolution], dim * 2) + else: + raise ValueError('Unknown RSTB block type.') + + def forward(self, x, x_size, skip=None): + if self.block_type == 'U': + assert skip is not None, "Skip connection is required for patch expand" + x = self.reshape(x, skip) + + out = self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) + block_out = self.patch_embed(out) + x + + if self.block_type == 'D': + block_out = (self.reshape(block_out), block_out) # return skip connection + + return block_out + +class PatchEmbed(nn.Module): + r""" Fixed Image to Patch Embedding. Flattens image along spatial dimensions + without learned projection mapping. Only supports 1x1 patches. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + embed_dim (int): Number of output channels. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + +class PatchEmbedLearned(nn.Module): + r""" Image to Patch Embedding with arbitrary patch size and learned linear projection. 
+ Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_chans * patch_size[0] * patch_size[1], embed_dim) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.patch_to_vec(x) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + def patch_to_vec(self, x): + """ + Args: + x: (B, C, H, W) + Returns: + patches: (B, num_patches, patch_size * patch_size * C) + """ + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(B, H // self.patch_size[0], self.patch_size[0], W // self.patch_size[1], self.patch_size[1], C) + patches = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.patch_size[0], self.patch_size[1], C) + patches = patches.contiguous().view(B, self.num_patches, self.patch_size[0] * self.patch_size[1] * C) + return patches + + +class PatchUnEmbed(nn.Module): + r""" Fixed Image to Patch Unembedding via reshaping. Only supports patch size 1x1. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + +class PatchUnEmbedLearned(nn.Module): + r""" Patch Embedding to Image with learned linear projection. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. 
+ """ + + def __init__(self, img_size, patch_size, in_chans=None, out_chans=None, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.out_chans = out_chans + + self.proj_up = nn.Linear(in_chans, in_chans * patch_size[0] * patch_size[1]) + + def forward(self, x, x_size=None): + B, HW, C = x.shape + x = self.proj_up(x) + x = self.vec_to_patch(x) + return x + + def vec_to_patch(self, x): + B, HW, C = x.shape + x = x.view(B, self.patches_resolution[0], self.patches_resolution[1], C).contiguous() + x = x.view(B, self.patches_resolution[0], self.patch_size[0], self.patches_resolution[1], self.patch_size[1], C // (self.patch_size[0] * self.patch_size[1])).contiguous() + x = x.view(B, self.img_size[0], self.img_size[1], -1).contiguous().permute(0, 3, 1, 2) + return x + + +class HUMUSNet(nn.Module): + r""" HUMUS-Block + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depths (tuple(int)): Depth of each Swin Transformer layer in encoder and decoder paths. + num_heads (tuple(int)): Number of attention heads in different layers of encoder and decoder. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + use_checkpoint (bool): Whether to use checkpointing to save memory. + img_range: Image range. 1. or 255. + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + conv_downsample_first: use convolutional downsampling before MUST to reduce compute load on Transformers + """ + + def __init__(self, args, + img_size=[240, 240], + in_chans=1, + patch_size=1, + embed_dim=66, + depths=[2, 2, 2], + num_heads=[3, 6, 12], + window_size=5, + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + img_range=1., + resi_connection='1conv', + bottleneck_depth=2, + bottleneck_heads=24, + conv_downsample_first=True, + out_chans=1, + no_residual_learning=False, + **kwargs): + super(HUMUSNet, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans if out_chans is None else out_chans + + self.center_slice_out = (out_chans == 1) + + self.img_range = img_range + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + self.conv_downsample_first = conv_downsample_first + self.no_residual_learning = no_residual_learning + + ##################################################################################################### + ################################### 1, input block ################################### + input_conv_dim = embed_dim + self.conv_first = nn.Conv2d(num_in_ch, input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** self.num_layers) + self.mlp_ratio = mlp_ratio + + # Downsample for low-res feature extraction + if self.conv_downsample_first: + img_size = [im //2 for im in img_size] + self.conv_down_block = ConvBlock(input_conv_dim // 2, input_conv_dim, 0.0) + self.conv_down = DownsampConvBlock(input_conv_dim, input_conv_dim) + + # split image into non-overlapping patches + if patch_size > 1: + self.patch_embed = PatchEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + if patch_size > 1: + self.patch_unembed = PatchUnEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, out_chans=embed_dim, norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build MUST + # encoder + self.layers_down = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** i_layer) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + 
patches_resolution[1] // dim_scaler), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='D', + ) + self.layers_down.append(layer) + + # bottleneck + dim_scaler = (2 ** self.num_layers) + self.layer_bottleneck = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=bottleneck_depth, + num_heads=bottleneck_heads, + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='B', + ) + + # decoder + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** (self.num_layers - i_layer - 1)) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='U', + ) + self.layers_up.append(layer) + + self.norm_down = norm_layer(self.num_features) + self.norm_up = norm_layer(self.embed_dim) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(input_conv_dim, input_conv_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(input_conv_dim, input_conv_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim, 3, 1, 1)) + + # Upsample if needed + if self.conv_downsample_first: + self.conv_up_block = ConvBlock(input_conv_dim, input_conv_dim // 2, 0.0) + self.conv_up = TransposeConvBlock(input_conv_dim, input_conv_dim // 2) + + ##################################################################################################### + ################################ 3, output block ################################ + self.conv_last = nn.Conv2d(input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, num_out_ch, 3, 1, 1) + self.output_norm = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + 
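+    # Dimension flow (a descriptive sketch, assuming the default arguments above):
+    # with conv_downsample_first=True a 240x240 input is reduced to 120x120 before
+    # patch embedding, so patches_resolution = (120, 120). The three 'D' RSTBs then
+    # run at resolutions 120/60/30 with dims 66/132/264 (PatchMerging halves the
+    # resolution and doubles the channels), the bottleneck 'B' RSTB runs at 15x15
+    # with num_features = 66 * 2**3 = 528, and the three 'U' RSTBs mirror the
+    # encoder through PatchExpandSkip using the stored skip connections.
+    # window_size=5 divides every intermediate resolution (120, 60, 30, 15).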
@torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + + # divisible by window size + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + + # divisible by total downsampling ratio + # this could be done more efficiently by combining the two + total_downsamp = int(2 ** (self.num_layers - 1)) + pad_h = (total_downsamp - h % total_downsamp) % total_downsamp + pad_w = (total_downsamp - w % total_downsamp) % total_downsamp + x = F.pad(x, (0, pad_w, 0, pad_h)) + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # encode + skip_cons = [] + for i, layer in enumerate(self.layers_down): + x, skip = layer(x, x_size=layer.input_resolution) + skip_cons.append(skip) + x = self.norm_down(x) # B L C + + # bottleneck + x = self.layer_bottleneck(x, self.layer_bottleneck.input_resolution) + + # decode + for i, layer in enumerate(self.layers_up): + x = layer(x, x_size=layer.input_resolution, skip=skip_cons[-i-1]) + x = self.norm_up(x) + x = self.patch_unembed(x, x_size) + return x + + def forward(self, x): + # print("input shape:", x.shape) + # print("input range:", x.max(), x.min()) + C, H, W = x.shape[1:] + center_slice = (C - 1) // 2 + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.conv_downsample_first: + x_first = self.conv_first(x) + x_down = self.conv_down(self.conv_down_block(x_first)) + res = self.conv_after_body(self.forward_features(x_down)) + res = self.conv_up(res) + res = torch.cat([res, x_first], dim=1) + res = self.conv_up_block(res) + + res = self.conv_last(res) + + if self.no_residual_learning: + x = res + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + res + else: + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + + if self.no_residual_learning: + x = self.conv_last(res) + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + self.conv_last(res) + + # print("output range before scaling:", x.max(), x.min()) + # print("self.mean:", self.mean) + + if self.center_slice_out: + x = x / self.img_range + self.mean[:, center_slice, ...].unsqueeze(1) + else: + x = x / self.img_range + self.mean + + x = self.output_norm(x) + + return x[:, :, :H, :W] + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + +class DownsampConvBlock(nn.Module): + """ + A Downsampling Convolutional Block that consists of one strided convolution + layer followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.Conv2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H/2, W/2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return HUMUSNet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..02588c24559c0604ddbbd03a14b3b32e6cff2c8d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py @@ -0,0 +1,300 @@ +""" +Xiaohan Xing, 2023/11/22 +两个模态分别用CNN提取多个层级的特征, 每个层级都把特征沿着宽度拼接起来,得到(h, 2w, c)的特征。 +然后学习得到(h, 2w)的attention map, 和原始的特征相乘送到后面的层。 +""" + +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.attn_layers = nn.ModuleList() + + for l in range(self.num_pool_layers): + + self.attn_layers.append(nn.Sequential(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob), + nn.Conv2d(chans*(2**l), 2, kernel_size=1, stride=1), + nn.Sigmoid())) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, attn_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.attn_layers): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat之后求attention, 然后进行modulation. 
+ attn = attn_layer(torch.cat((output1, output2), 1)) + # print("spatial attention:", attn[:, 0, :, :].unsqueeze(1).shape) + # print("output1:", output1.shape) + output1 = output1 * (1 + attn[:, 0, :, :].unsqueeze(1)) + output2 = output2 * (1 + attn[:, 1, :, :].unsqueeze(1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/kspace_mUnet_concat.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/kspace_mUnet_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..d16bfdc83305d3e1f490e5ca4b445c6d730557ce --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/kspace_mUnet_concat.py @@ -0,0 +1,351 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Predict the low-frequency mask ############## +class Mask_Predictor(nn.Module): + def __init__(self, in_chans, out_chans, chans, drop_prob): + super(Mask_Predictor, self).__init__() + self.conv1 = ConvBlock(in_chans, chans, drop_prob) + self.conv2 = ConvBlock(chans, chans*2, drop_prob) + self.conv3 = ConvBlock(chans*2, chans*4, drop_prob) + + # self.conv4 = ConvBlock(chans*4, 1, drop_prob) + + self.FC = nn.Linear(chans*4, out_chans) + self.act = nn.Sigmoid() + + + def forward(self, x): + ### 三层卷积和down-sampling. + # print("[Mask predictor] input data range:", x.max(), x.min()) + output = F.avg_pool2d(self.conv1(x), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer1_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv2(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer2_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv3(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer3_feature range:", output.max(), output.min()) + + feature = F.adaptive_avg_pool2d(F.relu(output), (1, 1)).squeeze(-1).squeeze(-1) + # print("mask_prediction feature:", output[:, :5]) + # print("[Mask predictor] bottleneck feature range:", feature.max(), feature.min()) + + output = self.FC(feature) + # print("output before act:", output) + output = self.act(output) + # print("mask output:", output) + + return output + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + 
def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.mask_predictor1 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + self.mask_predictor2 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
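+            Note: in this k-space variant the forward pass actually returns a tuple of
+            six tensors rather than a single output: the two reconstructed outputs, the
+            predicted low-frequency DC mask coordinates for each modality, and the
+            corresponding DC weights (see the return statement below).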
+ """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + ### 预测做low-freq DC的mask区域. + ### 输出三个维度, 前两维是坐标,最后一维是mask内部的权重 + t1_mask = self.mask_predictor1(output1) + t2_mask = self.mask_predictor2(output2) + + # print("t1_mask and t2_mask:", t1_mask.shape, t2_mask.shape) + t1_mask_coords, t2_mask_coords = t1_mask[:, :2], t2_mask[:, :2] + t1_DC_weight, t2_DC_weight = t1_mask[:, -1], t2_mask[:, -1] + + + # ### 预测做low-freq DC的DC weight + # t1_DC_weight = self.mask_predictor1(output1) + # t2_DC_weight = self.mask_predictor2(output2) + + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight ## + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mARTUnet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9607d41ca14562a94e542a1804af5aeaf15bc123 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mARTUnet.py @@ -0,0 +1,563 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +2023/07/20, Xiaohan Xing +将两个模态的图像在输入端concat, 得到num_channels = 2的input, 然后送到ART模型进行T2 modality的重建。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # print("feature after layer_norm:", x.max(), x.min()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", 
attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=2, + out_channels=1, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img, aux_img): + stack = [] + bs, _, H, W = inp_img.shape + fuse_img = torch.cat((inp_img, aux_img), 1) + inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + 
self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=2, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_ARTfusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..610b1e1e38d74f695a3a19f22a6a80f3152bf713 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_ARTfusion.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import CS_TransformerBlock as TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 
整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + # output1, output2 = image1, image2 + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
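+            ### Per-level data flow (summary): each modality branch first runs its own ConvBlock;
+            ### the pre-fusion features are pushed onto the skip-connection stacks; the two branches
+            ### are then concatenated along channels, flattened to tokens, passed through this
+            ### level's ART transformer blocks, split back into the two modalities, and
+            ### average-pooled by 2 before the next encoder level.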
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py new file mode 100644 index 0000000000000000000000000000000000000000..55280925415f1db47cf746b59fb9710d47c9bff2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py @@ -0,0 +1,374 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any + +import matplotlib.pyplot as plt +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet_new import ConcatTransformerBlock, TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class 
mUnet_ARTfusion_SeqConcat(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4. + ): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.vis_feature_maps = False + self.vis_attention_maps = False + # self.vis_feature_maps = args.MODEL.VIS_FEAT_MAPS + # self.vis_attention_maps = args.MODEL.VIS_ATTN_MAPS + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + if l < 2: + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * (2 ** (l + 1))), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + else: + self.fuse_layers.append(nn.ModuleList([ConcatTransformerBlock(dim=int(chans * (2 ** l)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + feature_maps = { + 'modal1': {'before': [], 'after': []}, + 'modal2': {'before': [], 'after': []} + } + + attention_maps = [] + """ + attention_maps = [ + unet layer 1 attention maps (x2): [map from TransBlock1, TransBlock2, ...], + unet layer 2 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 3 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 4 attention maps (x6): [map from TransBlock1, TransBlock2, ...], + ] + """ + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + if self.vis_feature_maps: + feature_maps['modal1']['before'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['before'].append(output2.detach().cpu().numpy()) + + ### 两个模态的特征concat + # output = torch.cat((output1, output2), 1) + + + + unet_layer_attn_maps = [] + """ + unet_layer_attn_maps: [ + attn_map from TransformerBlock1: (B, nHGroups, nWGroups, winSize, winSize), + attn_map from TransformerBlock2: (B, nHGroups, nWGroups, winSize, winSize), + ... + ] + """ + if l < 2: + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + else: + output1 = rearrange(output1, "b c h w -> b (h w) c").contiguous() + output2 = rearrange(output2, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + if l < 2: + if self.vis_attention_maps: + output, attn_map = layer(output, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output = layer(output, [H // (2**l), W // (2**l)]) + + else: + if self.vis_attention_maps: + output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + attention_maps.append(unet_layer_attn_maps) + + if l < 2: + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + else: + output1 = rearrange(output1, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output2 = rearrange(output2, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + # if self.vis_attention_maps: + # output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) # attn_map: (B, nHGroups, nWGroups, winSize, winSize) + # unet_layer_attn_maps.append(attn_map) + # else: + # output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + # attention_maps.append(unet_layer_attn_maps) + + + + if self.vis_feature_maps: + feature_maps['modal1']['after'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['after'].append(output2.detach().cpu().numpy()) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + if self.vis_feature_maps: + return output1, output2, feature_maps#, relation_stack1, relation_stack2 #, t1_features, t2_features + elif self.vis_attention_maps: + return output1, output2, attention_maps + else: + return output1, output2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
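+    Shape sketch (illustrative, assuming H and W are multiples of 15): for an input of
+    shape (B, C, 30, 30), the map is regrouped into a (B, 4*C, 15, 15) tile layout,
+    adaptively average-pooled to 5x5, flattened to 25 position vectors, and the pairwise
+    cosine similarities between positions form the returned (B, 25, 25) relation matrix.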
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion_SeqConcat(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8ab29b641fa0231ad0d163224a4a90b44c32a9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .restormer import TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8]): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.Sequential(*[TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[0], ffn_expansion_factor=2.66, \ + bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + ### 将encoder中multi-modal fusion之后的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = fuse_layer(output) + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2a209758203b198502154b75496333ef91717a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mUnet_transformer.py @@ -0,0 +1,387 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+ +Xiaohan Xing, 2023/06/08 +分别用CNN提取两个模态的特征,然后送到transformer进行特征交互, 将经过transformer提取的特征和之前CNN提取的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Decoder_wo_skip(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder_wo_skip, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + output = transpose_conv(output) + output = conv(output) + + return output + + + + + +class mUnet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
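+
+# --- Illustrative aside (not part of the original patch) --------------------
+# Quick shape check for the ViT-style `Transformer` defined above: it maps a
+# token sequence of shape (batch, num_tokens, dim) to a tensor of the same
+# shape. The sizes below are assumptions chosen only for illustration.
+#
+#   vit = Transformer(dim=256, depth=2, heads=8, dim_head=64, mlp_dim=1024, dropout=0.0)
+#   tokens = torch.randn(2, 451, 256)   # e.g. 2 modalities x 225 patches + 1 cls token
+#   assert vit(tokens).shape == (2, 451, 256)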
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch*2 + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1, output2 = image1, image2 + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack1, output1 = self.encoder1(output1) ### output size: [4, 256, 15, 15] + stack2, output2 = self.encoder2(output2) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output1) # [4, 225, 256], 225 is the number of patches. + patch_embed_input1 = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + n_embed = self.to_patch_embedding(output2) # [4, 225, 256], 225 is the number of patches. 
+ patch_embed_input2 = F.normalize(n_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input1, patch_embed_input2), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1]/2)), int(np.sqrt(feature_output.shape[1]/2)) + patch_num = feature_output.shape[1] + feature_output1 = feature_output[:, :patch_num//2, :] + feature_output2 = feature_output[:, patch_num//2:, :] + feature_output1 = feature_output1.contiguous().view(b, feature_output1.shape[-1], h, w) # [4,c,15,15] + feature_output2 = feature_output2.contiguous().view(b, feature_output2.shape[-1], h, w) + + output = torch.cat((output1, output2), 1) + torch.cat((feature_output1, feature_output2), 1) + # print("CNN feature range:", output1.max(), output1.min()) + # print("Transformer feature range:", feature_output1.max(), feature_output1.min()) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
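+            (With kernel_size=2, stride=2 and no padding, the transposed
+            convolution output size is (H - 1) * 2 + 2 = 2 * H, hence the
+            exact spatial doubling.)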
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_Transformer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee0c4784d04638df4ca259613777dde34980a2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/04 +Concatenate the input images from different modalities and input it into the UNet. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_ART.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..1a09f2900a8924cdc1a7c373270e6ff8aadcfa45 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_ART.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in 
range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
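+
+# --- Illustrative aside (not part of the original patch) --------------------
+# The ART blocks in the forward pass above operate on token sequences, so each
+# (N, C, H, W) feature map is flattened to (N, H*W, C) before the blocks and
+# restored afterwards. A minimal round trip of that reshaping (names are
+# illustrative; the forward pass above assumes square feature maps when it
+# restores the map):
+#
+#   def to_tokens(x):                                   # (N, C, H, W) -> (N, H*W, C)
+#       return x.flatten(2).permute(0, 2, 1)
+#
+#   def to_map(tokens, h, w):                           # (N, H*W, C) -> (N, C, H, W)
+#       n, _, c = tokens.shape
+#       return tokens.permute(0, 2, 1).reshape(n, c, h, w)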
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_ART_v2.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_ART_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe6bc0d9192a126f39932b40443badb5b60068 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_ART_v2.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. 
+ num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
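+
+        Note: compared with mmunet_ART.py, this v2 variant drops the explicit
+        `pre_norm=False` argument when building the ART TransformerBlocks (the
+        block's default is used instead) and keeps the `nn.Tanh()` activation
+        on the final output head; the forward pass is otherwise essentially
+        unchanged.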
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_early_fusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_early_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d062f057a2080a471005125a00ee27453a4b0706 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_early_fusion.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +encapsulate the encoder and decoder for the Unet model. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + stack, output = self.encoder(output) + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + output = self.decoder(output, stack) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_late_fusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b48e7d4445afd6a5bc1eaa2b065c37431585e94 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_late_fusion.py @@ -0,0 +1,229 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +use Unet for the reconstruction of each modality, concat the intermediate features of both modalities. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("decoder output:", output.shape) + + return output + + + + +class mmUnet_late(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack1, output1 = self.encoder1(image) + stack2, output2 = self.encoder2(aux_image) + ### fuse two modalities to get multi-modal representation. + output = torch.cat([output1, output2], 1) + output = self.conv(output) ### from [4, 512, 15, 15] to [4, 512, 15, 15] + + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet_late(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_restormer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..88c7920d65c58ab82a00309a1ae501a8d5b33804 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_restormer.py @@ -0,0 +1,237 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + 
ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
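+
+# --- Illustrative aside (not part of the original patch) --------------------
+# The module docstring above refers to "channel-wise long-range dependent
+# features" from the imported restormer_block.TransformerBlock, whose source is
+# not shown in this diff. Purely as a sketch of the general idea (NOT that
+# module's implementation): attention can be taken over channels instead of
+# spatial positions, giving a C x C affinity map.
+#
+#   def toy_channel_attention(x):                             # x: (N, C, H, W)
+#       n, c, h, w = x.shape
+#       tokens = F.normalize(x.flatten(2), dim=-1)            # (N, C, H*W)
+#       attn = (tokens @ tokens.transpose(1, 2)).softmax(-1)  # (N, C, C)
+#       out = attn @ x.flatten(2)                              # re-mix channels
+#       return x + out.view(n, c, h, w)                        # residual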
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e3afe68612727d3195872a121ec8040fb0f66ef2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py @@ -0,0 +1,235 @@ +""" +2023/07/18, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +之前 3 blocks的模型深层特征存在严重的gradient vanishing问题。现在把模型的encoder和decoder都改成只有3个blocks, 看是否还存在gradient vanishing问题。 +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 3, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2] + heads=[1, 2, 4] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_restormer_v2.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b51d3a218057206863230973d557173e891d3cd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mmunet_restormer_v2.py @@ -0,0 +1,220 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +这个版本中是将Unet encoder提取的特征直接连接到decoder, 经过restormer refine的特征只是传递到下一层的encoder block. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + output = feature + + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mtrans_mca.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mtrans_mca.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4c89df4af71e986fa7c24d522e6732371f3128 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mtrans_mca.py @@ -0,0 +1,119 @@ +import torch +from torch import nn + +from .mtrans_transformer import Mlp + +# implement by YunluYan + + +class LinearProject(nn.Module): + def __init__(self, dim_in, dim_out, fn): + super(LinearProject, self).__init__() + need_projection = dim_in != dim_out + self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity() + self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity() + self.fn = fn + + def forward(self, x, *args, **kwargs): + x = self.project_in(x) + x = self.fn(x, *args, **kwargs) + x = self.project_out(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): + super(MultiHeadCrossAttention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim*2, bias=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x, complement): + + # x [B, HW, C] + B_x, N_x, C_x = x.shape + + x_copy = x + + complement = torch.cat([x, complement], 1) + + B_c, N_c, C_c = complement.shape + + # q [B, heads, HW, C//num_heads] + q = self.to_q(x).reshape(B_x, N_x, self.num_heads, C_x//self.num_heads).permute(0, 2, 1, 3) + kv = self.to_kv(complement).reshape(B_c, N_c, 2, self.num_heads, C_c//self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_x, N_x, C_x) + + x = x + x_copy + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio = 1., attn_drop=0., proj_drop=0.,drop_path = 0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super(CrossTransformerEncoderLayer, self).__init__() + self.x_norm1 = norm_layer(dim) + self.c_norm1 = norm_layer(dim) + + self.attn = MultiHeadCrossAttention(dim, num_heads, attn_drop, proj_drop) + + self.x_norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) + + self.drop1 = nn.Dropout(drop_path) + self.drop2 = nn.Dropout(drop_path) + + def forward(self, x, complement): + x = self.x_norm1(x) + complement = self.c_norm1(complement) + + x = x + self.drop1(self.attn(x, complement)) + x = x + self.drop2(self.mlp(self.x_norm2(x))) + return x + + + +class CrossTransformer(nn.Module): + def __init__(self, x_dim, c_dim, depth, num_heads, mlp_ratio =1., attn_drop=0., proj_drop=0., drop_path =0.): + super(CrossTransformer, self).__init__() + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LinearProject(x_dim, c_dim, CrossTransformerEncoderLayer(c_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)), + LinearProject(c_dim, x_dim, 
CrossTransformerEncoderLayer(x_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)) + ])) + + def forward(self, x, complement): + for x_attn_complemnt, complement_attn_x in self.layers: + x = x_attn_complemnt(x, complement=complement) + x + complement = complement_attn_x(complement, complement=x) + complement + return x, complement + + + + + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mtrans_net.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mtrans_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e80dadb45725603c7ff1da664a45533575db25 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mtrans_net.py @@ -0,0 +1,145 @@ +""" +This script implement the paper "Multi-Modal Transformer for Accelerated MR Imaging (TMI 2022)" +""" + + +import torch + +from torch import nn +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler +from .mtrans_mca import CrossTransformer + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(INPUT_DIM, HEAD_HIDDEN_DIM): + return ReconstructionHead(INPUT_DIM, HEAD_HIDDEN_DIM) + + +class CrossCMMT(nn.Module): + + def __init__(self, args): + super(CrossCMMT, self).__init__() + + INPUT_SIZE = 240 + INPUT_DIM = 1 # the channel of input + OUTPUT_DIM = 1 # the channel of output + HEAD_HIDDEN_DIM = 16 # the hidden dim of Head + TRANSFORMER_DEPTH = 4 # the depth of the transformer + TRANSFORMER_NUM_HEADS = 4 # the head's num of multi head attention + TRANSFORMER_MLP_RATIO = 3 # the MLP RATIO Of transformer + TRANSFORMER_EMBED_DIM = 256 # the EMBED DIM of transformer + P1 = 8 + P2 = 16 + CTDEPTH = 4 + + + self.head = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + self.head2 = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + + x_patch_dim = HEAD_HIDDEN_DIM * P1 ** 2 + x_num_patches = (INPUT_SIZE // P1) ** 2 + + complement_patch_dim = HEAD_HIDDEN_DIM * P2 ** 2 + complement_num_patches = (INPUT_SIZE // P2) ** 2 + + self.x_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P1, + p2=P1), + ) + + self.complement_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P2, + p2=P2), + ) + + self.x_pos_embedding = nn.Parameter(torch.randn(1, x_num_patches, x_patch_dim)) + self.complement_pos_embedding = nn.Parameter(torch.randn(1, complement_num_patches, complement_patch_dim)) + + ### + self.cross_transformer = CrossTransformer(x_patch_dim, complement_patch_dim, CTDEPTH, + TRANSFORMER_NUM_HEADS, TRANSFORMER_MLP_RATIO) + + self.p1 = P1 + self.p2 = P2 + + self.tail1 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) 
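+        # With the constants defined above (INPUT_SIZE=240, P1=8, P2=16,
+        # HEAD_HIDDEN_DIM=16), the patch embeddings have shapes
+        # x: (B, 900, 1024) and complement: (B, 225, 4096); tail1/tail2 are 1x1
+        # convolutions that project each branch back to the OUTPUT_DIM image channels.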
+ self.tail2 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x, complement): + x = self.head(x) + complement = self.head2(complement) + + x = x + complement + + b, _, h, w = x.shape + + _, _, c_h, c_w = complement.shape + + ### 得到每个模态的patch embedding (加上了position encoding) + x = self.x_patch_embbeding(x) + x += self.x_pos_embedding + + complement = self.complement_patch_embbeding(complement) + complement += self.complement_pos_embedding + + ### 将两个模态的patch embeddings送到cross transformer中提取特征。 + x, complement = self.cross_transformer(x, complement) + + c = int(x.shape[2] / (self.p1 * self.p1)) + H = int(h / self.p1) + W = int(w / self.p1) + + x = x.reshape(b, H, W, self.p1, self.p1, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w, ) + x = self.tail1(x) + + complement = complement.reshape(b, int(c_h/self.p2), int(c_w/self.p2), self.p2, self.p2, int(complement.shape[2]/self.p2/self.p2)) + complement = complement.permute(0, 5, 1, 3, 2, 4) + complement = complement.reshape(b, -1, c_h, c_w) + + complement = self.tail2(complement) + + return x, complement + + +def build_model(args): + return CrossCMMT(args) + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mtrans_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mtrans_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..61314a6a81247b0740a1e98d078242b26ce73f90 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/mtrans_transformer.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() 
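+        # Cross-attention block: queries are computed from the query tokens `q`
+        # passed to forward(), while keys and values come from the feature tokens `x`.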
+ self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative 
distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(args, embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=args.MODEL.TRANSFORMER_DEPTH, + num_heads=args.MODEL.TRANSFORMER_NUM_HEADS, + mlp_ratio=args.MODEL.TRANSFORMER_MLP_RATIO) + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_concat_decomp.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_concat_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..07c6f64804dd586927814801b167163f0b65f58f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_concat_decomp.py @@ -0,0 +1,295 @@ +""" +Xiaohan Xing, 2023/06/25 +两个模态分别用CNN提取多个层级的特征, 每个层级的特征都经过concat之后得到fused feature, +然后每个模态都通过一层conv变换得到2C*h*w的特征, 其中C个channels作为common feature, 另C个channels作为specific features. +利用CDDFuse论文中的decomposition loss约束特征解耦。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.transform_layers1 = nn.ModuleList() + self.transform_layers2 = nn.ModuleList() + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.transform_layers1.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.transform_layers2.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + stack1_common, stack1_specific, stack2_common, stack2_specific = [], [], [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_transform_layer, net2_transform_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.transform_layers1, self.transform_layers2, \ + self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
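+            # At each level: a transform ConvBlock doubles the channels of each
+            # modality's feature and the result is split into a "common" half and a
+            # modality-"specific" half (the CDDFuse-style decomposition mentioned in
+            # the module docstring); each modality is then fused from its own common
+            # and specific features plus the other modality's common feature, and a
+            # fusion ConvBlock maps the concatenation back to the original channel
+            # width before average pooling to the next level.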
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + # print("original feature:", output1.shape) + + ### 分别提取两个模态的common feature和specific feature + output1 = net1_transform_layer(output1) + # print("transformed feature:", output1.shape) + output1_common = output1[:, :(output1.shape[1]//2), :, :] + output1_specific = output1[:, (output1.shape[1]//2):, :, :] + output2 = net2_transform_layer(output2) + output2_common = output2[:, :(output2.shape[1]//2), :, :] + output2_specific = output2[:, (output2.shape[1]//2):, :, :] + + stack1_common.append(output1_common) + stack1_specific.append(output1_specific) + stack2_common.append(output2_common) + stack2_specific.append(output2_specific) + + ### 将一个模态的两组feature和另一模态的common feature合并, 然后经过conv变换得到和初始特征相同维度的fused feature送到下一层。 + output1 = net1_fuse_layer(torch.cat((output1_common, output1_specific, output2_common), 1)) + output2 = net2_fuse_layer(torch.cat((output2_common, output2_specific, output1_common), 1)) + # print("fused feature:", output1.shape) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, stack1_common, stack1_specific, stack2_common, stack2_specific + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_multi_concat.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_multi_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90373e1134210809b98fdc64e3d51d76123855 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_multi_concat.py @@ -0,0 +1,306 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # 
padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + # output1, output2 = image1, image2 + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
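+            # At each level: each modality's pre-fusion feature is stored for the
+            # decoder skip connection and average-pooled; the features of the two
+            # branches are then concatenated and a per-modality fusion ConvBlock
+            # reduces them back to that level's channel width for the next level.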
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # relation1 = get_relation_matrix(output1) + # relation2 = get_relation_matrix(output2) + # relation_stack1.append(relation1) + # relation_stack2.append(relation2) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + # output1 = net1_layer(output1) + # output2 = net2_layer(output2) + + # ### 两个模态的特征concat + # fused_output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + # fused_output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # stack1.append(fused_output1) + # stack2.append(fused_output2) + + # ### downsampling features + # output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + # output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_multi_sum.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_multi_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4efb13b54ccbeaeb34fc51d3e72b0c50ee242 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_multi_sum.py @@ -0,0 +1,270 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom 
+ if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers): + + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + + ## 两个模态的特征求平均. + fused_output1 = (output1 + output2)/2.0 + fused_output2 = (output1 + output2)/2.0 + + output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features #, relation_stack1, relation_stack2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_multi_transfuse.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_multi_transfuse.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ce250a078ba0847c729c1ef1b76452f64d0f07 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_multi_transfuse.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +2023/06/23 +每个层级的特征用TransFuse layer融合之后, 将两个模态的original features和transfuse features concat, +然后在每个模态中经过conv层变换得到下一层的输入。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .TransFuse import TransFuse_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_TransFuse(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + ### conv layer to transform the features after Transfuse layer. + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + self.n_anchors = 15 + self.avgpool = nn.AdaptiveAvgPool2d((self.n_anchors, self.n_anchors)) + self.transfuse_layer1 = TransFuse_layer(n_embd=self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer2 = TransFuse_layer(n_embd=2*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer3 = TransFuse_layer(n_embd=4*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer4 = TransFuse_layer(n_embd=8*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + + self.transfuse_layers = nn.Sequential(self.transfuse_layer1, self.transfuse_layer2, self.transfuse_layer3, self.transfuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, transfuse_layer, net1_fuse_layer, net2_fuse_layer in zip(self.encoder1.down_sample_layers, \ + self.encoder2.down_sample_layers, self.transfuse_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### extract multi-level features from the encoder of each modality. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + ### Transformer-based multi-modal fusion layer + feature1 = self.avgpool(output1) + feature2 = self.avgpool(output2) + feature1, feature2 = transfuse_layer(feature1, feature2) + feature1 = F.interpolate(feature1, scale_factor=output1.shape[-1]//self.n_anchors, mode='bilinear') + feature2 = F.interpolate(feature2, scale_factor=output2.shape[-1]//self.n_anchors, mode='bilinear') + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
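One fusion level of `mUnet_TransFuse.forward`, with the tensor bookkeeping made explicit: both modality features are adaptively pooled to `n_anchors` x `n_anchors` (15x15) tokens, passed through the cross-modal `TransFuse_layer`, interpolated back to the feature-map resolution, and the four maps are concatenated and projected back to C channels. In the sketch below the transformer call is left as a comment and a plain `Conv2d` stands in for the `ConvBlock` projection; the sizes are illustrative and must be multiples of 15 for the integer `scale_factor` to restore the original resolution.

```python
# Sketch of one fusion level from mUnet_TransFuse.forward. The fuse_conv here
# is a stand-in for ConvBlock(4C -> C); the real cross-modal transformer call
# is marked as a comment.
import torch
import torch.nn.functional as F
from torch import nn

n_anchors, C = 15, 32
out1 = torch.randn(1, C, 240, 240)          # modality-1 encoder feature at this level
out2 = torch.randn(1, C, 240, 240)          # modality-2 encoder feature
avgpool = nn.AdaptiveAvgPool2d((n_anchors, n_anchors))
fuse_conv = nn.Conv2d(4 * C, C, kernel_size=3, padding=1)

f1, f2 = avgpool(out1), avgpool(out2)       # (1, C, 15, 15) anchor tokens
# f1, f2 = transfuse_layer(f1, f2)          # cross-modal transformer fusion in the real model
scale = out1.shape[-1] // n_anchors
f1 = F.interpolate(f1, scale_factor=scale, mode='bilinear')
f2 = F.interpolate(f2, scale_factor=scale, mode='bilinear')

fused1 = fuse_conv(torch.cat((out1, out2, f1, f2), dim=1))
print(fused1.shape)                         # torch.Size([1, 32, 240, 240])
```

Note that in the forward pass as written, `output1` is overwritten with its fused version before the concatenation for `output2` is assembled, so the second modality fuses against the already-fused first stream rather than the raw encoder feature; whether that asymmetry is intended is not stated.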
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_TransFuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_swinfusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_swinfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6203c4331744f1da70f18e403360196a419b4cb5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/munet_swinfusion.py @@ -0,0 +1,268 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .SwinFuse_layer import SwinFusion_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_SwinFuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过SwinFusion的方式融合。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.fuse_type = 'swinfuse' + self.img_size = 240 + # self.fuse_type = 'sum' + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # Swin transformer fusion modules + + self.fuse_layer1 = SwinFusion_layer(img_size=(self.img_size//2, self.img_size//2), patch_size=4, window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer2 = SwinFusion_layer(img_size=(self.img_size//4, self.img_size//4), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=2*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer3 = SwinFusion_layer(img_size=(self.img_size//8, self.img_size//8), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=4*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer4 = SwinFusion_layer(img_size=(self.img_size//16, self.img_size//16), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=8*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.swinfusion_layers = nn.Sequential(self.fuse_layer1, self.fuse_layer2, self.fuse_layer3, self.fuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, swinfuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.swinfusion_layers): + output1 = net1_layer(output1) + stack1.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # print("output of encoders:", output1.shape, output2.shape) + # print("CNN feature range:", output1.max(), output2.min()) + + ### Swin transformer-based multi-modal fusion. + feature1, feature2 = swinfuse_layer(output1, output2) + output1 = (output1 + feature1 + output2 + feature2)/4.0 + output2 = (output1 + feature1 + output2 + feature2)/4.0 + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_SwinFuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/original_MINet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/original_MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a115872320f677d8b34264eed95c22fff0df9c4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/original_MINet.py @@ -0,0 +1,335 @@ +from fastmri.models import common +import torch +import torch.nn as nn +import torch.nn.functional as F + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return 
res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,n_resgroups,n_resblocks,n_feats,conv=common.default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 2 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + common.Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class MINet(nn.Module): + def __init__(self, n_resgroups,n_resblocks, n_feats): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1,x2#x1=pd x2=pdfs + \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/restormer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3699cc809e438be4e1bd5efaba7d96e18d7f1602 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/restormer.py @@ -0,0 +1,307 @@ +## Restormer: Efficient 
Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + # print("feature before FFN projection:", x.max(), x.min()) + x = self.project_out(x) + # print("FFN output:", x.max(), x.min()) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + 
b,c,h,w = x.shape + # print("Attention block input:", x.max(), x.min()) + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + # print("attention values:", attn) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("Attention block output:", out.max(), out.min()) + + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + # dim = 32, + # num_blocks = [4,4,4,4], + dim = 32, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + + # stack = [] + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + # print("encoder block1 feature:", out_enc_level1.shape, out_enc_level1.max(), out_enc_level1.min()) + # stack.append(out_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + # print("encoder block2 feature:", out_enc_level2.shape, out_enc_level2.max(), out_enc_level2.min()) + # stack.append(out_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + # print("encoder block3 feature:", out_enc_level3.shape, out_enc_level3.max(), out_enc_level3.min()) + # 
stack.append(out_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + # print("encoder block4 feature:", latent.shape, latent.max(), latent.min()) + # stack.append(latent) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + + return out_dec_level1 #, stack + + + +def build_model(args): + return Restormer() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/restormer_block.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/restormer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..170ae6a826e5a3e356cb48f8c89500c2d5242816 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/restormer_block.py @@ -0,0 +1,321 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class 
LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + # print("input to the LayerNorm layer:", x.max(), x.min()) + h, w = x.shape[-2:] + x = to_4d(self.body(to_3d(x)), h, w) + # print("output of the LayerNorm:", x.max(), x.min()) + return x + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature = 1.0 / ((dim//num_heads) ** 0.5) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + # print("input to the Attention layer:", x.max(), x.min()) + + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + # print("q range:", q.max(), q.min()) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + # print("scaling parameter:", self.temperature) + # print("attention matrix before softmax:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + attn = attn.softmax(dim=-1) + # print("attention matrix range:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + # print("v range:", v.max(), v.min(), v.mean()) + + out = (attn @ v) + # print("self-attention output:", out.max(), out.min()) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("output of attention layer:", out.max(), out.min()) + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, pre_norm): + super(TransformerBlock, self).__init__() + + self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + self.pre_norm = pre_norm + + def forward(self, 
x): + x = self.conv(x) + # print("restormer block input:", x.max(), x.min()) + if self.pre_norm: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + else: + x = self.norm1(x + self.attn(x)) + x = self.norm2(x + self.ffn(x)) + # x = self.norm1(self.attn(x)) + # x = self.norm2(self.ffn(x)) + # x = x + self.attn(x) + # x = x + self.ffn(x) + # print("restormer block output:", x.max(), x.min()) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim = 48, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + 
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + + + return out_dec_level1 + + + + +if __name__ == '__main__': + model = Restormer() + # print(model) + + A = torch.randn((1, 1, 240, 240)) + x = model(A) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/swinIR.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/swinIR.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfc9c175c02531ac80a05ba4abfc0776f752dd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/swinIR.py @@ -0,0 +1,874 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
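`window_partition` and `window_reverse` above are exact inverses of each other when H and W are multiples of the window size, which is what lets SwinIR move between `(B, H, W, C)` feature maps and per-window token batches without losing information. A quick round-trip check with illustrative sizes:

```python
# Round-trip check of window_partition / window_reverse from swinIR.py.
# The 56x56x96 shape and window size 7 are illustrative assumptions;
# H and W must be multiples of the window size.
import torch

x = torch.randn(2, 56, 56, 96)                     # (B, H, W, C)
win = window_partition(x, window_size=7)           # (2*8*8, 7, 7, 96)
print(win.shape)                                   # torch.Size([128, 7, 7, 96])

y = window_reverse(win, window_size=7, H=56, W=56)
print(torch.equal(x, y))                           # True: partitioning is lossless
```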
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * 
(self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
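+ Note: blocks in this layer alternate between regular window attention (W-MSA, shift_size=0 for even block indices) and shifted window attention (SW-MSA, shift_size=window_size//2 for odd indices), following the Swin Transformer design used in the constructor below.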
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 4, 4, 6], + embed_dim=32, num_heads=[6, 6, 6, 6], mlp_ratio=4, upsampler='pixelshuffledirect') + + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/swin_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..877bdfbffc91012e4ac31f41b838ebb116f4a825 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/swin_transformer.py @@ -0,0 +1,878 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
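+# Note: in this repository (networks/compare_models) SwinIR serves as a comparison backbone; +# build_model() at the bottom of this file instantiates it with upscale=1, img_size=(240, 240), window_size=8, +# embed_dim=60, depths=[2, 2, 2], num_heads=[6, 6, 6] and upsampler='pixelshuffledirect', i.e. same-resolution, +# single-channel restoration rather than super-resolution.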
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=1, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + # print("shallow feature:", x.shape) + x = self.conv_after_body(self.forward_features(x)) + x + # print("deep feature:", x.shape) + x = self.upsample(x) + # print("after upsample:", x.shape) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x 
= self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + # return SwinIR(upscale=1, img_size=(240, 240), + # window_size=8, img_range=1., depths=[6, 6, 6, 6], + # embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + +if __name__ == '__main__': + height = 240 + width = 240 + window_size = 8 + model = SwinIR(upscale=1, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + x = torch.randn((1, 1, height, width)) + print("input:", x.shape) + x = model(x) + print("output:", x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet.zip b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet.zip new file mode 100644 index 0000000000000000000000000000000000000000..e7b71c1287551fd6119efd7c62da2196307998f5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:369de2a28036532c446409ee1f7b953d5763641ffd452675613e1bb9bba89eae +size 19354 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet/trans_unet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet/trans_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa61dbadfcaee9b26594307adc90cdd1d2a84e3e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet/trans_unet.py @@ -0,0 +1,489 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math + +from os.path import join as pjoin + +import torch +import torch.nn as nn +import numpy as np + +from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair +from scipy import ndimage +from . 
import vit_seg_configs as configs +# import .vit_seg_configs as configs +from .vit_seg_modeling_resnet_skip import ResNetV2 + + +logger = logging.getLogger(__name__) + + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class 
Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + # print("image size:", img_size, "grid size:", grid_size) + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + # print("patch_size_real", patch_size_real) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + # print("input x:", x.shape) ### (4, 3, 240, 240) + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + # print("output of the hybrid model:", x.shape) ### (4, 1024, 15, 15) + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = 
np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + # print(self.encoder) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) ### "features" store the features from different layers. + # print("embedding_output:", embedding_output.shape) ## (B, n_patch, hidden) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + # print("encoded feature:", encoded.shape) + return encoded, attn_weights, features + + # return embedding_output, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class ReconstructionHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + 
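        # Channel plumbing for the decoder: conv_more projects the transformer
        # hidden size down to head_channels (512); each DecoderBlock then takes
        # [head_channels] + decoder_channels[:-1] as its input widths and
        # decoder_channels as its output widths, and skip_channels entries
        # beyond config.n_skip are zeroed so the later stages run without
        # encoder skip connections.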
self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + # print(self.blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): + super(VisionTransformer, self).__init__() + self.num_classes = num_classes + self.zero_head = zero_head + self.classifier = config.classifier + self.transformer = Transformer(config, img_size, vis) + self.decoder = DecoderCup(config) + self.segmentation_head = ReconstructionHead( + in_channels=config['decoder_channels'][-1], + out_channels=config['n_classes'], + kernel_size=3, + ) + self.activation = nn.Tanh() + self.config = config + + def forward(self, x): + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + ### hybrid CNN and transformer to extract features from the input images. + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + + # ### extract features with CNN only. 
+ # x, features = self.transformer(x) # (B, n_patch, hidden) + + x = self.decoder(x, features) + # logits = self.segmentation_head(x) + logits = self.activation(self.segmentation_head(x)) + # print("logits:", logits.shape) + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + + + +def build_model(args): + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + print(net) + # net.load_from(weights=np.load(config_vit.pretrained_path)) + return net + + +if __name__ == "__main__": + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + net.load_from(weights=np.load(config_vit.pretrained_path)) + + input_data = torch.rand(4, 1, 240, 240).cuda() + print("input:", input_data.shape) + output = net(input_data) + print("output:", output.shape) diff --git 
a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bed7d3346e30328a38ac6fef1621f788d905df1e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py @@ -0,0 +1,132 @@ +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 4 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + # config.patches.grid = (16, 16) + config.patches.grid = (15, 15) ### for the input image with size (240, 240) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'recon' + config.pretrained_path = '/home/xiaohan/workspace/MSL_MRI/code/pretrained_model/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 1 + config.n_skip = 3 + config.activation = 'tanh' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. 
customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae80ef7d96c46cb91bda849901aef34008e62ed --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py @@ -0,0 +1,160 @@ +import math + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. 
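            # The 1x1 projection is only built when the stride or the channel
            # count changes; gn_proj uses one group per channel, so the shortcut
            # is normalized per channel before being added to the main branch.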
+ self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - 
x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/transformer_modules.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/transformer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..900042884305619e39ad9b792b3b1936dfbd37b3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/transformer_modules.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + 
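        # attn1 keeps the raw (pre-softmax) scores; the softmaxed and dropped
        # attn3 below is what actually weights the values, while a permuted copy
        # of attn1 is returned together with x so the raw attention maps can be
        # inspected or visualized by the caller.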
attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=4, + num_heads=4, + mlp_ratio=3) + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dd935cd34a121d8079a8f5184d4ea29e15c8da --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unet.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F + + +class Unet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = image + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + # print("unet feature range:", output.max(), output.min()) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unet_restormer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f78aa5f192957b2706c6f691dfc66c427b65ae --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unet_restormer.py @@ -0,0 +1,234 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + # print("Unet feature range:", output.max(), output.min()) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unet_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..982fe23cec321de6b636e04c51b62037054ea432 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unet_transformer.py @@ -0,0 +1,346 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +对于Unet中的high-level feature, 用Transformer进行特征交互,然后和Transformer之前的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in 
enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Unet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output = image + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack, output = self.encoder(output) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output) # [4, 225, 256], 225 is the number of patches. + patch_embed_input = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1])), int(np.sqrt(feature_output.shape[1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[-1], h, w) # [4,512*2,24,24] + + output = torch.cat((output, feature_output), 1) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output = self.decoder(output, stack) + + return output + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. 
+ + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Transformer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unimodal_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unimodal_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c99214e10bee2f7a1be49f8560799ffbc14ed0a5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/compare_models/unimodal_transformer.py @@ -0,0 +1,112 @@ +import torch + +from torch import nn +from .transformer_modules import build_transformer +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(): + return ReconstructionHead(input_dim=1, hidden_dim=12) + + + +class CMMT(nn.Module): + + def __init__(self, args): + super(CMMT, self).__init__() + + self.head = build_head() + + HEAD_HIDDEN_DIM = 12 + PATCH_SIZE = 16 + INPUT_SIZE = 240 + OUTPUT_DIM = 1 + + patch_dim = HEAD_HIDDEN_DIM* PATCH_SIZE ** 2 + num_patches = (INPUT_SIZE // PATCH_SIZE) ** 2 + + self.transformer = build_transformer(patch_dim) + + self.patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=PATCH_SIZE, p2=PATCH_SIZE), + ) + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, patch_dim)) + + + self.p1 = PATCH_SIZE + self.p2 = PATCH_SIZE + + self.tail = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x): + + b,_ , h, w = x.shape + + x = self.head(x) + + x= self.patch_embbeding(x) + + x += self.pos_embedding + + x = self.transformer(x) # b HW p1p2c + + c = int(x.shape[2]/(self.p1*self.p2)) + H = int(h/self.p1) + W = 
int(w/self.p2) + + x = x.reshape(b, H, W, self.p1, self.p2, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w,) + x = self.tail(x) + + return x + + +def build_model(args): + return CMMT(args) + + + + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/modules.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/modules.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
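+# In this implementation each (sample, channel) slice is normalised over its
+# spatial positions, y = (x - mean_{H,W}(x)) * rsqrt(var_{H,W}(x) + eps), and,
+# when `affine=True`, scaled and shifted per channel by the learnable
+# `scale` / `shift` parameters (broadcast over batch, height and width).
+# The module-level flag USE_PYTORCH_IN switches InstanceNorm2d back to
+# torch.nn.InstanceNorm2d when the stock PyTorch behaviour is preferred.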
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/mynet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/mynet.py new file mode 100644 index 0000000000000000000000000000000000000000..91c2966b09a9261e23582c29093b9e59ebd0d4be --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks/mynet.py @@ -0,0 +1,389 @@ +import torch +from torch import nn +from networks import common_freq as common + + +class TwoBranch(nn.Module): + def __init__(self, args): + super(TwoBranch, self).__init__() + + num_group = 4 + num_every_group = args.base_num_every_group + + self.args = args + + self.init_T2_frq_branch(args) + self.init_T2_spa_branch(args, num_every_group) + self.init_T2_fre_spa_fusion(args) + + self.init_T1_frq_branch(args) + self.init_T1_spa_branch(args, num_every_group) + + self.init_modality_fre_fusion(args) + self.init_modality_spa_fusion(args) + + + def init_T2_frq_branch(self, args): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_fre = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down2_fre = 
[common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_neck_fre = [common.FreBlock9(args.num_features, args) + ] + self.neck_fre = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up1_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up1_fre = nn.Sequential(*modules_up1_fre) + self.up1_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up2_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up2_fre = nn.Sequential(*modules_up2_fre) + self.up2_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up3_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up3_fre = nn.Sequential(*modules_up3_fre) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + # define tail module + modules_tail_fre = [ + common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act)] + self.tail_fre = nn.Sequential(*modules_tail_fre) + + def init_T2_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1 = nn.Sequential(*modules_down1) + + + self.down1_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2 = nn.Sequential(*modules_down2) + + self.down2_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down3 = nn.Sequential(*modules_down3) + self.down3_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.neck = nn.Sequential(*modules_neck) + + self.neck_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_up1 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up1 = nn.Sequential(*modules_up1) + + self.up1_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, 
act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_up2 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up2 = nn.Sequential(*modules_up2) + self.up2_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + + modules_up3 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up3 = nn.Sequential(*modules_up3) + self.up3_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + # define tail module + modules_tail = [ + common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act)] + + self.tail = nn.Sequential(*modules_tail) + + def init_T2_fre_spa_fusion(self, args): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(args.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, args): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_fre_T1 = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre_T1 = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre_T1 = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre_T1 = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_neck_fre = [common.FreBlock9(args.num_features, args) + ] + self.neck_fre_T1 = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + def init_T1_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_T1 = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = nn.Sequential(*modules_down1) + + + self.down1_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = nn.Sequential(*modules_down2) + + self.down2_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, 
norm=None) + ] + self.down3_T1 = nn.Sequential(*modules_down3) + self.down3_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.neck_T1 = nn.Sequential(*modules_neck) + + self.neck_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + + def init_modality_fre_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + def forward(self, main, aux): + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo) # 64 + up2_fre_mo = self.up2_fre_mo(up2_fre) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse) + down1_fuse_mo = self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.conv_fuse_spa[1](down1_mo_t1, 
down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.up3(up2_fuse_mo) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + # import matplotlib.pyplot as plt + # plt.axis('off') + # plt.imshow((255*up3_fre_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fre_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_fuse_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fuse_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + # breakpoint() + + res = self.tail(up3_fuse_mo) + + return {'img_out': res + main, 'img_fre': res_fre + main} + +def make_model(args): + return TwoBranch(args) + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__init__.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__pycache__/__init__.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec7c2be5c402cfec0f2e69f12c7f5f8009f5c738 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__pycache__/common_freq.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__pycache__/common_freq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d46ca9e28b71351f9198ec503d74369c13cb2118 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__pycache__/common_freq.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__pycache__/mynet.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__pycache__/mynet.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2dc72877432da5f84820859ecef1638e99aaa61c Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/__pycache__/mynet.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/common_freq.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..559392e6252c5be1e8b94d4d3895771450160d67 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/common_freq.py @@ -0,0 +1,411 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
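+            # Each stage below: invPixelShuffle(2) halves H and W while multiplying the
+            # channel count by 4, then a conv maps 4*n_feats back to n_feats, so
+            # log2(scale) stages give the overall scale-fold spatial reduction.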
+ for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None, temp_channel=None): + super(ConvBNReLU2D, self).__init__() + + if not isinstance(temp_channel, type(None)): + self.temb_proj = torch.nn.Linear(temp_channel, out_channels) + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs, temb=None): + + out = self.layers(inputs) + + if not isinstance(temb, type(None)): + out = out + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features, temp_channel=None): + super(ResBlock, self).__init__() + self.layers1 = ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1, temp_channel=temp_channel) + self.layers2 = ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + + + def forward(self, inputs, temp=None): + x = self.layers1(inputs, temp) + x = self.layers2(x) + + return F.relu(x + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): 
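+        # With the default scale=2: 3x3 conv -> invPixelShuffle(2) (H and W halved,
+        # channels x4) -> 1x1 conv back to num_features; with scale=1 the two convs
+        # keep the input resolution and channel count unchanged.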
+ return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks, temp_channel=None): + super(ResidualGroup, self).__init__() + + self.head = ResBlock(n_feat, temp_channel) # Use to be two + + modules_body = [ResBlock(n_feat) for _ in range(n_resblocks - 1)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x, t=None): + x = self.head(x, t) + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels, args): + super(FreBlock9, self).__init__() + + self.temb_proj = torch.nn.Linear(args.temb_channels, channels) + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x, temb=None): + # print("x: ", x.shape) + _, _, H, W = x.shape + + if not isinstance(temb, type(None)): + x = x + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa 
= nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = 
fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ARTUnet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae23fed16b0d9454a5ea09f17b4322f1e9d7e928 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ARTUnet.py @@ -0,0 +1,751 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + # print("x shape:", x.shape) + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +########################################################################## +class CS_TransformerBlock(nn.Module): + r""" 在ART Transformer Block的基础上添加了channel attention的分支。 + 参考论文: Dual Aggregation Transformer for Image Super-Resolution + 及其代码: https://github.com/zhengchen1999/DAT/blob/main/basicsr/archs/dat_arch.py + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.InstanceNorm2d(dim), + nn.GELU()) + + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + # nn.InstanceNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1)) + + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.InstanceNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # print("x before dwconv:", x.shape) + # convolution output + conv_x = self.dwconv(x.permute(0, 3, 1, 2)) + # conv_x = x.permute(0, 3, 1, 2) + # print("x after dwconv:", x.shape) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + attened_x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + attened_x = attened_x[:, :H, :W, :].contiguous() + + attened_x = attened_x.view(B, H * W, C) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C) + # S-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W) + spatial_map = self.spatial_interaction(attention_reshape) + + # C-I + attened_x = attened_x * torch.sigmoid(channel_map) + # S-I + conv_x = torch.sigmoid(spatial_map) * conv_x + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + + x = attened_x + conv_x + + # 
FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ARTUnet_new.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ARTUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d210d46e2a4d73781b3b16b63a53fdc0f25b9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ARTUnet_new.py @@ -0,0 +1,823 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +Unet的格式构建image restoration 网络。每个transformer block中都是dense self-attention和sparse self-attention交替连接。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True, visualize_attention_maps=False): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + self.visualize_attention_maps = visualize_attention_maps + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + # assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + # qkv: 3, B_, num_heads, N, C // num_heads + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + # B_, num_heads, N, C // num_heads * + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + attn_map = attn.detach().cpu().numpy().mean(axis=1) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) # (B * nP, nHead, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + # attn: B * nP, num_heads, N, N + # v: B * nP, num_heads, N, C // num_heads + # -attn-@-v-> B * nP, num_heads, N, C // num_heads + # -transpose-> B * nP, N, num_heads, C 
// num_heads + # -reshape-> B * nP, N, C + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, G ** 2, G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, Gh * Gw, Gh * Gw) + else: + x = self.attn(x, Gh, Gw, mask=attn_mask) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +class ConcatTransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. 
+ num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + ################## + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // 
self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + ################## + + x = torch.cat((x, y), dim=1) + + shortcut = x + x = self.norm1(x) + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, 2 * Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, 2 * G ** 2, 2 * G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, 2 * Gh * Gw, 2 * Gh * Gw) + else: + x = self.attn(x, 2 * Gh, Gw, mask=attn_mask) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + output = x + + # print(x.shape, Gh, Gw) + x = output[:, :Gh * Gw, :] + y = output[:, Gh * Gw:, :] + assert x.shape == y.shape + + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = y.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = y.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, y, attn_map + else: + return x, y + + def get_patches(self, x, x_size): + H, W = x_size + B, H, W, C = x.shape + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, 
Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + return x, attn_mask, (Hd, Wd), (Gh, Gw), (pad_r, pad_b), G, I + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ART_Restormer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ART_Restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c23123f0ac1a8e82e2bf907658ce8d91951c3e4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ART_Restormer.py @@ -0,0 +1,592 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure, 每个block都是由ART和Restormer串联而成。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
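+                # A concrete illustration of the sparse grouping above (sizes chosen for
+                # illustration only): if Hd = Wd = 48 and interval I = 12, then
+                # Gh = Gw = 48 // 12 = 4, so the feature map is regrouped into I*I = 144
+                # groups of Gh*Gw = 16 tokens, each group sampling positions 12 pixels
+                # apart. Padded rows/columns are marked with NEG_INF in attn_mask so the
+                # softmax assigns them near-zero attention weight.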
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[4, 4, 4, 4], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + + self.down1_2 = Downsample(dim) ## 
From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + 
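+        # up2_1 halves the channels back to `dim`; since the level-1 skip feature (also
+        # `dim` channels) is concatenated without a 1x1 reduction, the level-1 decoder and
+        # refinement blocks below run on 2*dim channels, and the final 3x3 output conv
+        # maps 2*dim back to out_channels.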
self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + + ######################################### Encoder level1 ############################################ + ### ART encoder block + for layer in self.ART_encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + + ### Restormer encoder block + out_enc_level1 = rearrange(out_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = self.Restormer_encoder_level1(out_enc_level1) + # stack.append(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + ######################################### Encoder level2 ############################################ + ### ART encoder block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.ART_encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + out_enc_level2 = rearrange(out_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = self.Restormer_encoder_level2(out_enc_level2) + # stack.append(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.ART_encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + out_enc_level3 = rearrange(out_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = self.Restormer_encoder_level3(out_enc_level3) + # stack.append(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 
############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.ART_latent: + latent = layer(latent, [H // 8, W // 8]) + + ### Restormer encoder block + latent = rearrange(latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = self.Restormer_latent(latent) + # stack.append(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + out_dec_level3 = inp_dec_level3 + for layer in self.ART_decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + out_dec_level3 = rearrange(out_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + out_dec_level3 = self.Restormer_decoder_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + out_dec_level2 = inp_dec_level2 + for layer in self.ART_decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + out_dec_level2 = rearrange(out_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + out_dec_level2 = self.Restormer_decoder_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + out_dec_level1 = inp_dec_level1 + for layer in self.ART_decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + ### Restormer decoder block + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_decoder_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_refinement(out_dec_level1) + + out_dec_level1 = 
self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ART_Restormer_v2.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ART_Restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..de8f921e81d79cb7af14a75326f0ddaaf3b3f4c3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ART_Restormer_v2.py @@ -0,0 +1,663 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure; each block is built from an ART branch and a Restormer branch in parallel. +The input features are processed by the ART and Restormer branches separately; the two feature sets are concatenated, and a conv layer transforms them into the block's output features. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.enc_fuse_level1 = nn.Conv2d(int(dim * 
2 ** 1), dim, kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.enc_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.enc_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + self.enc_fuse_level4 = nn.Conv2d(int(dim * 2 ** 4), int(dim * 2 ** 3), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.dec_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], 
window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.dec_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.dec_fuse_level1 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + # self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + # bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + # self.fuse_refinement = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + # def forward(self, inp_img, aux_img): + # stack = [] + # bs, _, H, W = inp_img.shape + + # fuse_img = torch.cat((inp_img, aux_img), 1) + # inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + + def forward(self, inp_img): + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + + ######################################### Encoder level1 ############################################ + art_enc_level1 = inp_enc_level1 + restor_enc_level1 = inp_enc_level1 + + ### ART encoder block + for layer in self.ART_encoder_level1: + art_enc_level1 = layer(art_enc_level1, [H, W]) + + ### Restormer encoder block + restor_enc_level1 = rearrange(restor_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_enc_level1 = self.Restormer_encoder_level1(restor_enc_level1) ### (b, c, h, w) + + art_enc_level1 = rearrange(art_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = torch.cat((art_enc_level1, restor_enc_level1), 1) + out_enc_level1 = self.enc_fuse_level1(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + + ######################################### Encoder level2 ############################################ + ### ART encoder 
block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + art_enc_level2 = inp_enc_level2 + restor_enc_level2 = inp_enc_level2 + + for layer in self.ART_encoder_level2: + art_enc_level2 = layer(art_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + restor_enc_level2 = rearrange(restor_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + restor_enc_level2 = self.Restormer_encoder_level2(restor_enc_level2) + + art_enc_level2 = rearrange(art_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = torch.cat((art_enc_level2, restor_enc_level2), 1) + out_enc_level2 = self.enc_fuse_level2(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + art_enc_level3 = inp_enc_level3 + restor_enc_level3 = inp_enc_level3 + + for layer in self.ART_encoder_level3: + art_enc_level3 = layer(art_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + restor_enc_level3 = rearrange(restor_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + restor_enc_level3 = self.Restormer_encoder_level3(restor_enc_level3) + + art_enc_level3 = rearrange(art_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = torch.cat((art_enc_level3, restor_enc_level3), 1) + out_enc_level3 = self.enc_fuse_level3(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 ############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + art_latent = inp_enc_level4 + restor_latent = inp_enc_level4 + + for layer in self.ART_latent: + art_latent = layer(art_latent, [H // 8, W // 8]) + + ### Restormer encoder block + restor_latent = rearrange(restor_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + restor_latent = self.Restormer_latent(restor_latent) + + art_latent = rearrange(art_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = torch.cat((art_latent, restor_latent), 1) + latent = self.enc_fuse_level4(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + art_dec_level3 = inp_dec_level3 + restor_dec_level3 = inp_dec_level3 + + for layer in self.ART_decoder_level3: + art_dec_level3 = layer(art_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + restor_dec_level3 = rearrange(restor_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + restor_dec_level3 = 
self.Restormer_decoder_level3(restor_dec_level3) + + art_dec_level3 = rearrange(art_dec_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_dec_level3 = torch.cat((art_dec_level3, restor_dec_level3), 1) + out_dec_level3 = self.dec_fuse_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + art_dec_level2 = inp_dec_level2 + restor_dec_level2 = inp_dec_level2 + + for layer in self.ART_decoder_level2: + art_dec_level2 = layer(art_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + restor_dec_level2 = rearrange(restor_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + restor_dec_level2 = self.Restormer_decoder_level2(restor_dec_level2) + + art_dec_level2 = rearrange(art_dec_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_dec_level2 = torch.cat((art_dec_level2, restor_dec_level2), 1) + out_dec_level2 = self.dec_fuse_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + art_dec_level1 = inp_dec_level1 + restor_dec_level1 = inp_dec_level1 + + for layer in self.ART_decoder_level1: + art_dec_level1 = layer(art_dec_level1, [H, W]) + + ### Restormer decoder block + restor_dec_level1 = rearrange(restor_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_dec_level1 = self.Restormer_decoder_level1(restor_dec_level1) + + art_dec_level1 = rearrange(art_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = torch.cat((art_dec_level1, restor_dec_level1), 1) + out_dec_level1 = self.dec_fuse_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +# def build_model(args): +# return ART_Restormer( +# inp_channels=2, +# out_channels=1, +# dim=32, +# num_art_blocks=[2, 4, 4, 6], +# num_restormer_blocks=[2, 2, 2, 2], +# num_refinement_blocks=4, +# heads=[1, 2, 4, 8], +# window_size=[10, 10, 10, 10], mlp_ratio=4., +# qkv_bias=True, qk_scale=None, +# interval=[24, 12, 6, 3], +# ffn_expansion_factor = 2.66, +# LayerNorm_type = 'WithBias', ## Other option 'BiasFree' +# bias=False) + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 
2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ARTfuse_layer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ARTfuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..78af0b27528b50db897936f6b8b4eee51d566b1c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/ARTfuse_layer.py @@ -0,0 +1,705 @@ +""" +ART layer in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input to the Attention layer:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + # print("q range:", q.max(), q.min()) + # print("k range:", k.max(), k.min()) + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + # print("attention matrix:", attn.shape, attn) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + # print("feature before proj-layer:", x.max(), x.min()) + x = self.proj(x) + # print("feature after proj-layer:", x.max(), x.min()) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. 
+ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pre_norm=True): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + # self.conv = nn.Conv2d(dim, dim, kernel_size=1) + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + # self.conv = ConvBlock(self.dim, self.dim, drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.bn_norm = nn.BatchNorm2d(dim) + self.pre_norm = pre_norm + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + # x = self.conv(x) + shortcut = x + if self.pre_norm: + # print("using pre_norm") + x = self.norm1(x) ## 归一化之后特征范围变得非常小 + x = x.view(B, H, W, C) + # print("normalized input range:", x.max(), x.min(), x.mean()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 
1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + # print("range of feature before layer:", x.max(), x.min(), x.mean()) + # x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + # x_bn = self.bn_norm(x.permute(0, 3, 1, 2)) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + # print("[ART layer] input and output feature difference:", torch.mean(torch.abs(shortcut - x))) + # print("range of feature before ART layer:", shortcut.max(), shortcut.min(), shortcut.mean()) + # print("range of feature after ART layer:", x.max(), x.min(), x.mean()) + if self.pre_norm: + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + # print("using post_norm") + x = self.norm1(shortcut + self.drop_path(x)) + x = self.norm2(x + self.drop_path(self.mlp(x))) + # x = shortcut + self.drop_path(x) + # x = x + self.drop_path(self.mlp(x)) + + # print("range of fused feature:", x.max(), x.min()) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + + + +########################################################################## +class Cross_TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
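+ Both are token sequences of shape (B, H*W, C), where (H, W) is given by x_size.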
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +class Cross_TransformerBlock_v2(nn.Module): + r""" ART Transformer Block. + 将两个模态的特征沿着channel方向concat, 一起取window. 之后把取出来的两个模态中同一个window的所有patches合并,送到transformer处理。 + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + xy = torch.cat((x, y), -1) + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + xy = F.pad(xy, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + xy = xy.reshape(B, Hd // G, G, Wd // G, G, 2*C).permute(0, 1, 3, 2, 4, 5).contiguous() + xy = xy.reshape(B * Hd * Wd // G ** 2, G ** 2, 2*C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + xy = xy.reshape(B, Gh, I, Gw, I, 2*C).permute(0, 2, 4, 1, 3, 5).contiguous() + xy = xy.reshape(B * I * I, Gh * Gw, 2*C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ x = xy[:, :, :C] + y = xy[:, :, C:] + xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DataConsistency.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DataConsistency.py new file mode 100644 index 0000000000000000000000000000000000000000..5f460e9acabcb33e1a3f812eb9acaeb337da446c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DataConsistency.py @@ -0,0 +1,48 @@ +""" +Created: DataConsistency @ Xiyang Cai, 2023/09/09 + +Data consistency layer for k-space signal. 
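+It keeps the originally sampled k-space locations fixed, i.e. out = (1 - mask) * k + mask * k0.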
+ +Ref: DataConsistency in DuDoRNet (https://github.com/bbbbbbzhou/DuDoRNet) + +""" + +import torch +from torch import nn +from einops import repeat + + +def data_consistency(k, k0, mask): + """ + k - input in k-space + k0 - initially sampled elements in k-space + mask - corresponding nonzero location + """ + out = (1 - mask) * k + mask * k0 + return out + + +class DataConsistency(nn.Module): + """ + Create data consistency operator + """ + + def __init__(self): + super(DataConsistency, self).__init__() + + def forward(self, k, k0, mask): + """ + k - input in frequency domain, of shape (n, nx, ny, 2) + k0 - initially sampled elements in k-space + mask - corresponding nonzero location (n, 1, len, 1) + """ + + if k.dim() != 4: # input is 2D + raise ValueError("error in data consistency layer!") + + # mask = repeat(mask.squeeze(1, 3), 'b x -> b x y c', y=k.shape[1], c=2) + mask = torch.tile(mask, (1, mask.shape[2], 1, k.shape[-1])) ### [n, 320, 320, 2] + # print("k and k0 shape:", k.shape, k0.shape) + out = data_consistency(k, k0, mask) + + return out, mask diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DuDo_mUnet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DuDo_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a56069a27f0cdd392482b41dc4e3b79252295487 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DuDo_mUnet.py @@ -0,0 +1,359 @@ +""" +Xiaohan Xing, 2023/10/25 +传入两个模态的kspace data, 经过kspace network进行重建。 +然后把重建之后的kspace data变换回图像,然后送到image domain的网络进行图像重建。 +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + +from .Kspace_mUnet import kspace_mmUnet + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = kspace_ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = kspace_ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = kspace_ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = kspace_ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. 
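+        # NOTE: by the convolution theorem, multiplying the k-space feature maps point-wise by
+        # the learned `spatial_weights` is equivalent to a (circular) convolution with an
+        # image-sized kernel in the image domain, which is what the comment above refers to.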
+ spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class DuDo_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 3, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + # self.kspace_net = mmConvKSpace() + self.kspace_net = kspace_mmUnet(args) + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
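+            Besides this output, the forward pass returns the k-space-reconstructed image, the
+            reconstructed and masked k-space tensors (None when args.domain == "image") and a
+            dict of per-sample normalization statistics (see the return statement below).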
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + stack = [] + feature_stack = [] + # output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ output = torch.cat((image, LR_image, aux_image), 1) #.detach() + + # print("LR image range:", LR_image.max(), LR_image.min()) + # print("kspace_recon image range:", image.max(), image.min()) + # print("aux_image range:", aux_image.max(), aux_image.min()) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output, image, recon_kspace, masked_kspace, data_stats + + + +class kspace_ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return DuDo_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DuDo_mUnet_ARTfusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DuDo_mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6df67b63140ca51bd7b8664151ab4b1b7a6305d5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DuDo_mUnet_ARTfusion.py @@ -0,0 +1,369 @@ +""" +Xiaohan Xing, 2023/11/08 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + ARTfusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", 
output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
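+            Besides the two decoder outputs (reconstructions of the auxiliary and target
+            modalities), the k-space reconstruction results and the per-sample normalization
+            statistics are also returned (see the return statement below).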
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + output1, output2 = aux_image.detach(), LR_image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DuDo_mUnet_CatFusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DuDo_mUnet_CatFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c2b86b6e2031b1af07ab7cb58efc182b7a2673 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/DuDo_mUnet_CatFusion.py @@ -0,0 +1,338 @@ +""" +Xiaohan Xing, 2023/11/10 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + multi layer concat fusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + 
+ output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_CATfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t2_gt: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
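+            In addition to the two reconstructed modalities, the normalized input and
+            ground-truth images, the intermediate k-space results and the normalization
+            statistics are returned (see the return statement below).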
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + t2_image = t2_gt * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + # # print("aux_image range:", aux_image.max(), aux_image.min()) + # print("LR_image range:", LR_image.max(), LR_image.min()) + # print("kspace recon_img range:", image.max(), image.min()) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + t2_gt_image = self.normalize(t2_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + t2_gt_image = torch.clamp(t2_gt_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + # output1, output2 = aux_image.detach(), LR_image.detach() + output1, output2 = aux_image.detach(), image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image, recon_kspace, masked_kspace, data_stats + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_CATfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Kspace_ConvNet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Kspace_ConvNet.py new file mode 100644 index 0000000000000000000000000000000000000000..918cbe598a62653394c8c4eaf99cd10a45c064ca --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Kspace_ConvNet.py @@ -0,0 +1,140 @@ +""" +2023/10/05 +Xiaohan Xing +Build a simple network with several convolution layers. 
+Input: concat (undersampled kspace of FS-PD, fully-sampled kspace of PD modality) +Output: interpolated kspace of FS-PD +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, args, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + # self.conv_blocks = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # self.conv_blocks.append(ConvBlock(self.chans, self.chans * 2, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans * 2, self.chans, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans, self.out_chans, drop_prob)) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + # print("output of the three conv blocks:", output1.shape, output2.shape, output3.shape) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("conv1_output range:", output1.max(), output1.min()) + # print("conv2_output range:", output2.max(), output2.min()) + # print("conv3_output range:", output3.max(), output3.min()) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. + spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + # print("output before DC layer:", output.shape) + # print("output range before DC layer:", output.max(), output.min()) + # output = output + kspace ## residual connection + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + # mask_output = output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
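+    Note: in this k-space variant the LeakyReLU activations are commented out below, so the
+    block is effectively linear apart from instance normalization and dropout.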
+ """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +def build_model(args): + return mmConvKSpace(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Kspace_mUnet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Kspace_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2db9a357798be4abd4cfbd44e9207555770b0456 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Kspace_mUnet.py @@ -0,0 +1,202 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # print("output:", output.shape) + # print("kspace:", kspace.shape) + # print("mask:", mask.shape) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Kspace_mUnet_new.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Kspace_mUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ccfd28dc198ffeff2259a5fbed45dabdc76819 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Kspace_mUnet_new.py @@ -0,0 +1,200 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # print("using Unet without Relu layers") + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + # output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # # print("output:", output.shape) + # # print("kspace:", kspace.shape) + # # print("mask:", mask.shape) + # mask_output = mask * output + + return output # .permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/MINet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeeb063b12be1dc594e8e076e6a59e454fcfa0b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/MINet.py @@ -0,0 +1,356 @@ +""" +Compare with the method "Multi-contrast MRI Super-Resolution via a Multi-stage Integration Network". +The code is from https://github.com/chunmeifeng/MINet/edit/main/fastmri/models/MINet.py. 
+""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F + + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): 
+ def __init__(self,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 1 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + # print("modules_tail:", modules_tail) + # self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Sigmoid()) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + + + +class MINet(nn.Module): + def __init__(self, args, n_resgroups=3, n_resblocks=3, n_feats=64): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + # print("x1:", x1.shape, "x2:", x2.shape) + + ### 源代码中是super-resolution任务,所以self.tail对图像进行unsampling. 
我们做reconstruction把这步去掉 + x2 = self.tail(x2) + + + resT1 = x1 + resT2 = x2 + + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + # print("t1s:", [item.shape for item in t1s]) + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + # print("resT1 range:", resT1.max(), resT1.min()) + # print("resT2 range:", resT2.max(), resT2.min()) + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1, x2 #x1=pd x2=pdfs + + +def build_model(args): + return MINet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/MINet_common.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/MINet_common.py new file mode 100644 index 0000000000000000000000000000000000000000..631f3036db4edf8eab00d70c66d2058a3b65cd59 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/MINet_common.py @@ -0,0 +1,90 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
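# [Editor's note, not part of the original patch] Sub-pixel upsampling: each step below is a 3x3 conv to
# 4*n_feats channels followed by PixelShuffle(2), repeated log2(scale) times; scale == 3 instead uses one
# conv to 9*n_feats channels plus PixelShuffle(3). With scale == 1 (the reconstruction setting used by
# SR_Branch above) the loop runs zero times, so the Upsampler becomes an empty nn.Sequential, i.e. an
# identity mapping, which matches the author's comment about dropping the upsampling step.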
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + # print("upsampler:", m) + + super(Upsampler, self).__init__(*m) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/SANet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/SANet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfac3590e248c9532b002885c6c0b7a55ab1400a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/SANet.py @@ -0,0 +1,392 @@ +""" +This script contains the codes of "Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution (TNNLS 2023)" +https://github.com/chunmeifeng/SANet/blob/main/fastmri/models/SANet.py +""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import os + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = 
x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,scale,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups #10 + self.n_resblocks = n_resblocks #20 + self.n_feats = n_feats #64 + kernel_size = 3 + reduction = 16 #16 + # scale = args.scale[0] + + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace 
pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' + .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class Seatt(nn.Module): + def __init__(self, in_c): + super(Seatt, self).__init__() + self.reduce = nn.Conv2d(in_c * 2, 32, 1) + self.ff_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.bf_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.rgbd_pred_layer = Pred_Layer(32 * 2) + self.convq = nn.Conv2d(32, 32, 3, 1, 1) + + def forward(self, rgb_feat, dep_feat, pred): + feat = torch.cat((rgb_feat, dep_feat), 1) + + # vis_attention_map_rgb_feat(rgb_feat[0][0]) + # vis_attention_map_dep_feat(dep_feat[0][0]) + # vis_attention_map_feat(feat[0][0]) + + + feat = self.reduce(feat) + [_, _, H, W] = feat.size() + pred = torch.sigmoid(pred) + + ni_pred = 1 - pred + + ff_feat = self.ff_conv(feat * pred) + bf_feat = self.bf_conv(feat * (1 - pred)) + new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) + + return new_pred + +class SANet(nn.Module): + def __init__(self, args, scale=1, n_resgroups=3, n_resblocks=3, n_feats=64): + super(SANet, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + + cs = [64,64,64,64] + self.Seatts = nn.ModuleList([Seatt(c) for c in cs]) + + self.net1 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + print("nlayer:",nlayer) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=3, padding=1) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2, Seatt, map_conv in zip(self.net1.body._modules.items(), self.net2.body._modules.items(), self.Seatts, self.map_convs): + 
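# [Editor's note, not part of the original patch] Per residual group: map_conv turns the T1-branch
# features into a 1-channel prediction; Seatt sigmoids it into a soft foreground/background mask over
# the concatenated T1/T2 features, processes the two masked halves separately, and re-predicts a
# 64-channel correction that is added back onto the T2 branch (resT2 = res + resT2), while the T1
# branch itself passes through unchanged.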
+ name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + pred = map_conv(resT1) + res = Seatt(resT1,resT2,pred) + + resT2 = res+resT2 + + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + + return x1,x2 + + +def build_model(args): + return SANet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/SwinFuse_layer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/SwinFuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3dca8ba793d92961bc3f1d987e211fe542b303 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/SwinFuse_layer.py @@ -0,0 +1,1310 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
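# [Editor's sketch, not part of the original patch] Round trip through the window helpers defined above;
# _window_roundtrip_demo is a made-up name for illustration and assumes H and W are multiples of the
# window size, as required by window_partition / window_reverse.
def _window_roundtrip_demo():
    x = torch.randn(2, 56, 56, 96)        # B x H x W x C with H = W = 56 divisible by ws = 7
    windows = window_partition(x, 7)      # -> (2 * 8 * 8, 7, 7, 96)
    assert windows.shape == (128, 7, 7, 96)
    assert torch.equal(window_reverse(windows, 7, 56, 56), x)   # exact inverse (pure reshuffling)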
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # print("x shape:", x.shape) + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + # print("num heads:", self.num_heads) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + 
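# [Editor's note, not part of the original patch] i.e. each of the num_heads heads multiplies an
# N x (dim/num_heads) query block with its transposed key block, so the N x N attention map below
# costs num_heads * N * (dim/num_heads) * N multiply-adds.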
flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask 
is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + # print("x_windows:", x_windows.shape) + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. 
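# [Editor's note, not part of the original patch] Compared with SwinTransformerBlock above, this block
# keeps two streams and swaps roles inside each window: attn_A is a Cross_WindowAttention that draws its
# query from x and key/value from y, attn_B draws query from y and key/value from x, and each stream has
# its own LayerNorm, DropPath and MLP, so the two modalities exchange information symmetrically.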
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
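# [Editor's note, not part of the original patch] PatchMerging above halves the spatial resolution by
# grouping each 2x2 neighbourhood, concatenating the four C-dim tokens into 4C and projecting to 2C,
# e.g. (B, 56*56, 96) -> (B, 28*28, 192). The BasicLayer below stacks `depth` SwinTransformerBlocks,
# alternating shift_size 0 and window_size // 2 (regular W-MSA and shifted SW-MSA), with an optional
# PatchMerging downsample at the end.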
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + + # # 不使用残差连接 + # x = self.residual_group_A(x, x_size) + # y = self.residual_group_B(y, x_size) + + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion_layer(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
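# [Editor's note, not part of the original patch] PatchEmbed / PatchUnEmbed above do no projection; they
# only flatten (B, C, H, W) feature maps into token sequences (B, H*W, C) and back, so with the default
# patch_size = 1 every pixel becomes one token of dimension embed_dim. Illustrative shapes:
#     tokens = PatchEmbed(embed_dim=96)(x)                    # x: (2, 96, 64, 64) -> (2, 4096, 96)
#     x_back = PatchUnEmbed(embed_dim=96)(tokens, (64, 64))   # -> (2, 96, 64, 64)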
+ Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Fusion_num_heads=[6, 6], + window_size=7, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, img_range=1., resi_connection='1conv', + **kwargs): + super(SwinFusion_layer, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + 
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + + ### fusion layer. + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + self.norm_Fusion_B = norm_layer(self.num_features) + + # print("fusion layers:", len(self.layers_Fusion)) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + + def forward_features_Fusion(self, x, y): + + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + # x = torch.cat([x, y], 1) + # ## Downsample the feature in the channel dimension + # x = self.lrelu(self.conv_after_body_Fusion(x)) + # print("x range after fusion:", x.max(), x.min()) + # print("y 
range after fusion:", y.max(), y.min()) + # print("x:", x.shape) + + return x, y + + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = self.check_image_size(y) + + print("modality1 input:", x.max(), x.min()) + print("modality2 input:", y.max(), y.min()) + + # multi-modal feature fusion. + x, y = self.forward_features_Fusion(x, y) ### 返回两个模态融合之后的features. + print("modality1 output:", x.max(), x.min()) + print("modality2 output:", y.max(), y.min()) + + return x[:, :, :H, :W], y[:, :, :H, :W] + + + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + + +if __name__ == '__main__': + width, height = 120, 120 + swinfuse = SwinFusion_layer(img_size=(width, height), window_size=8, depths=[6, 6, 6], embed_dim=60, num_heads=[6, 6, 6]) + # print("SwinFusion layer:", swinfuse) + + A = torch.randn((4, 60, width, height)) + B = torch.randn((4, 60, width, height)) + + x = swinfuse(A, B) + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/SwinFusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/SwinFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3e72c2a3e8859423f3544043e8b9feaec002d0e2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/SwinFusion.py @@ -0,0 +1,1468 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # print("qkv after reshape:", qkv.shape) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. 
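+    Despite the name inherited from the self-attention version, this module computes
+    window-based cross-attention: the query is projected from the first input ``x`` and
+    the key/value pair from the second input ``y``.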
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
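+        Blocks alternate between regular and shifted windows: even-indexed blocks use
+        shift_size = 0 (W-MSA) and odd-indexed blocks use shift_size = window_size // 2 (SW-MSA).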
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + # 不使用残差连接 + # print("input x shape:", x.shape) + x = self.residual_group_A(x, x_size) + y = self.residual_group_B(y, x_size) + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
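+    The full fusion network: per-modality shallow convolutions, RSTB feature-extraction
+    branches, CRSTB cross-modal fusion, an RSTB reconstruction stage, and three final
+    convolutions with a ``tanh`` output.
+
+    Example (illustrative only; mirrors the self-test in ``__main__`` at the end of this file):
+        >>> model = SwinFusion(upscale=1, img_size=(240, 240), window_size=8,
+        ...                    embed_dim=60, mlp_ratio=2, upsampler='convs')
+        >>> A = torch.randn(1, 1, 240, 240)   # modality A, e.g. one MRI contrast
+        >>> B = torch.randn(1, 1, 240, 240)   # modality B
+        >>> fused = model(A, B)               # (1, 1, 240, 240), tanh-bounded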
+ + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Re_depths=[4], + Ex_num_heads=[6], Fusion_num_heads=[6, 6], Re_num_heads=[6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinFusion, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + print('in_chans: ', in_chans) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + ####修改shallow feature extraction 网络, 修改为2个3x3的卷积#### + self.conv_first1_A = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first1_B = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first2_A = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1) + self.conv_first2_B = nn.Conv2d(embed_dim_temp, embed_dim_temp, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + self.Re_num_layers = len(Re_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = 
patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + dpr_Re = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Re_depths))] # stochastic depth decay rule + # build Residual Swin Transformer blocks (RSTB) + self.layers_Ex_A = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_A.append(layer) + self.norm_Ex_A = norm_layer(self.num_features) + + self.layers_Ex_B = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_B.append(layer) + self.norm_Ex_B = norm_layer(self.num_features) + + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + 
self.norm_Fusion_B = norm_layer(self.num_features) + + self.layers_Re = nn.ModuleList() + for i_layer in range(self.Re_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Re_depths[i_layer], + num_heads=Re_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Re[sum(Re_depths[:i_layer]):sum(Re_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Re.append(layer) + self.norm_Re = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
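+            # convolutions for the SwinIR-style 'nearest+conv' upsampler (kept from the
+            # original code; the fusion forward pass below reconstructs through the
+            # conv_last1-3 path defined in the final else branch)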
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last1 = nn.Conv2d(embed_dim, embed_dim_temp, 3, 1, 1) + self.conv_last2 = nn.Conv2d(embed_dim_temp, int(embed_dim_temp/2), 3, 1, 1) + self.conv_last3 = nn.Conv2d(int(embed_dim_temp/2), num_out_ch, 3, 1, 1) + + self.tanh = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features_Ex_A(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_A: + x = layer(x, x_size) + + x = self.norm_Ex_A(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_Ex_B(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_B: + x = layer(x, x_size) + + x = self.norm_Ex_B(x) # B L C + x = self.patch_unembed(x, x_size) + return x + + def forward_features_Fusion(self, x, y): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + # print("x token:", x.shape, "y token:", y.shape) + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + x = torch.cat([x, y], 1) + ## Downsample the feature in the channel dimension + x = self.lrelu(self.conv_after_body_Fusion(x)) + + return x + + def forward_features_Re(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Re: + x = layer(x, x_size) + + x = self.norm_Re(x) # B L C + x = self.patch_unembed(x, x_size) + # Convolution + x = self.lrelu(self.conv_last1(x)) + x = self.lrelu(self.conv_last2(x)) + x = self.conv_last3(x) + return x + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = 
self.check_image_size(y) + + # self.mean_A = self.mean.type_as(x) + # self.mean_B = self.mean.type_as(y) + # self.mean = (self.mean_A + self.mean_B) / 2 + + # x = (x - self.mean_A) * self.img_range + # y = (y - self.mean_B) * self.img_range + + # Feedforward + x = self.forward_features_Ex_A(x) + y = self.forward_features_Ex_B(y) + # print("x before fusion:", x.shape) + # print("y before fusion:", y.shape) + x = self.forward_features_Fusion(x, y) + x = self.forward_features_Re(x) + x = self.tanh(x) + + # x = x / self.img_range + self.mean + # print("H:", H, "upscale:", self.upscale) + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='convs') + + +if __name__ == '__main__': + model = SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='convs') + # print(model) + + A = torch.randn((1, 1, 240, 240)) + B = torch.randn((1, 1, 240, 240)) + + x = model(A, B) + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/TransFuse.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/TransFuse.py new file mode 100644 index 0000000000000000000000000000000000000000..56ca3d2b24f382603c876e4f6909363a9794ffa3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/TransFuse.py @@ -0,0 +1,155 @@ +import math +from collections import deque + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torchvision import models + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. 
+ """ + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(n_embd, n_embd) + self.query = nn.Linear(n_embd, n_embd) + self.value = nn.Linear(n_embd, n_embd) + # regularization + self.attn_drop = nn.Dropout(attn_pdrop) + self.resid_drop = nn.Dropout(resid_pdrop) + # output projection + self.proj = nn.Linear(n_embd, n_embd) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop): + super().__init__() + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.mlp = nn.Sequential( + nn.Linear(n_embd, block_exp * n_embd), + nn.ReLU(True), # changed from GELU + nn.Linear(block_exp * n_embd, n_embd), + nn.Dropout(resid_pdrop), + ) + + def forward(self, x): + B, T, C = x.size() + + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + + return x + + +class TransFuse_layer(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, n_embd, n_head, block_exp, n_layer, + num_anchors, seq_len=1, + embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1): + super().__init__() + self.n_embd = n_embd + self.seq_len = seq_len + self.vert_anchors = num_anchors + self.horz_anchors = num_anchors + + # positional embedding parameter (learnable), image + lidar + self.pos_emb = nn.Parameter(torch.zeros(1, 2 * seq_len * self.vert_anchors * self.horz_anchors, n_embd)) + + self.drop = nn.Dropout(embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(n_embd, n_head, + block_exp, attn_pdrop, resid_pdrop) + for layer in range(n_layer)]) + + # decoder head + self.ln_f = nn.LayerNorm(n_embd) + + self.block_size = seq_len + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + + def forward(self, m1, m2): + """ + Args: + m1 (tensor): B*seq_len, C, H, W + m2 (tensor): B*seq_len, C, H, W + """ + + bz = m2.shape[0] // self.seq_len + h, w = m2.shape[2:4] + + # forward the image model for token embeddings + m1 = m1.view(bz, self.seq_len, -1, h, w) + m2 = m2.view(bz, self.seq_len, -1, h, w) + + # pad token embeddings along number of tokens dimension + token_embeddings = torch.cat([m1, m2], 
dim=1).permute(0,1,3,4,2).contiguous() + token_embeddings = token_embeddings.view(bz, -1, self.n_embd) # (B, an * T, C) + + # add (learnable) positional embedding for all tokens + x = self.drop(self.pos_emb + token_embeddings) # (B, an * T, C) + x = self.blocks(x) # (B, an * T, C) + x = self.ln_f(x) # (B, an * T, C) + x = x.view(bz, 2 * self.seq_len, self.vert_anchors, self.horz_anchors, self.n_embd) + x = x.permute(0,1,4,2,3).contiguous() # same as token_embeddings + + m1_out = x[:, :self.seq_len, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + m2_out = x[:, self.seq_len:, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + print("modality1 output:", m1_out.max(), m1_out.min()) + print("modality2 output:", m2_out.max(), m2_out.min()) + + return m1_out, m2_out + + +if __name__ == "__main__": + + feature1 = torch.randn((4, 512, 20, 20)) + feature2 = torch.randn((4, 512, 20, 20)) + model = TransFuse_layer(n_embd=512, n_head=4, block_exp=4, n_layer=8, num_anchors=20, seq_len=1) + print("TransFuse_layer:", model) + feat1, feat2 = model(feature1, feature2) + print(feat1.shape, feat2.shape) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Unet_ART.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Unet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b0f834270a921ae4eb428c92b3e782fbed19b3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/Unet_ART.py @@ -0,0 +1,243 @@ +""" +2023/07/06, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + output = conv_layer(output) + + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/__init__.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..334f1821a9b59e692641f70cdc61e7655b1a9c89 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/__init__.py @@ -0,0 +1,114 @@ +from .unet import build_model as UNET +from .mmunet import build_model as MUNET +from .humus_net import build_model as HUMUS +from .MINet import build_model as MINET +from .SANet import build_model as SANET +from .mtrans_net import build_model as MTRANS +from .unimodal_transformer import build_model as TRANS +from .cnn_transformer_new import build_model as CNN_TRANS +from .cnn import build_model as CNN +from .unet_transformer import build_model as UNET_TRANSFORMER +from .swin_transformer import build_model as SWIN_TRANS +from .restormer import build_model as RESTORMER + + +from .cnn_late_fusion import build_model as MULTI_CNN +from .mmunet_late_fusion import build_model as MMUNET_LATE +from .mmunet_early_fusion import build_model as MMUNET_EARLY +from .trans_unet.trans_unet import build_model as TRANSUNET +from .SwinFusion import build_model as SWINMULTI +from .mUnet_transformer import build_model as MUNET_TRANSFORMER +from .munet_multi_transfuse import build_model as MUNET_TRANSFUSE +from .munet_swinfusion import build_model as MUNET_SWINFUSE +from .munet_multi_concat import build_model as MUNET_CONCAT +from .munet_concat_decomp import build_model as MUNET_CONCAT_DECOMP +from .munet_multi_sum import build_model as MUNET_SUM +# from .mUnet_ARTfusion_new import build_model as MUNET_ART_FUSION +from .mUnet_ARTfusion import build_model as MUNET_ART_FUSION +from .mUnet_Restormer_fusion import build_model as MUNET_RESTORMER_FUSION +from .mUnet_ARTfusion_SeqConcat import build_model as MUNET_ART_FUSION_SEQ + +from .ARTUnet import build_model as ART +from .Unet_ART import build_model as UNET_ART +from .unet_restormer import build_model as UNET_RESTORMER +from .mmunet_ART import build_model as MMUNET_ART +from .mmunet_restormer import build_model as MMUNET_RESTORMER +from .mmunet_restormer_v2 import build_model as MMUNET_RESTORMER_V2 +from .mmunet_restormer_3blocks import build_model as MMUNET_RESTORMER_SMALL + +from .mARTUnet import build_model as ART_MULTI_INPUT + +from .ART_Restormer import build_model as ART_RESTORMER +from .ART_Restormer_v2 import build_model as ART_RESTORMER_V2 +from .swinIR import build_model as SWINIR + +from .Kspace_ConvNet import build_model as KSPACE_CONVNET +from .Kspace_mUnet_new import build_model as KSPACE_MUNET +from .kspace_mUnet_concat import build_model as KSPACE_MUNET_CONCAT +from .kspace_mUnet_AttnFusion import build_model as KSPACE_MUNET_ATTNFUSE + +from .DuDo_mUnet import build_model as DUDO_MUNET +from .DuDo_mUnet_ARTfusion import build_model as DUDO_MUNET_ARTFUSION +from 
.DuDo_mUnet_CatFusion import build_model as DUDO_MUNET_CONCAT + + +model_factory = { + 'unet_single': UNET, + 'humus_single': HUMUS, + 'transformer_single': TRANS, + 'cnn_transformer': CNN_TRANS, + 'cnn_single': CNN, + 'swin_trans_single': SWIN_TRANS, + 'trans_unet': TRANSUNET, + 'unet_transformer': UNET_TRANSFORMER, + 'restormer': RESTORMER, + 'unet_art': UNET_ART, + 'unet_restormer': UNET_RESTORMER, + 'art': ART, + 'art_restormer': ART_RESTORMER, + 'art_restormer_v2': ART_RESTORMER_V2, + 'swinIR': SWINIR, + + + 'munet_transformer': MUNET_TRANSFORMER, + 'munet_transfuse': MUNET_TRANSFUSE, + 'cnn_late_multi': MULTI_CNN, + 'unet_multi': MUNET, + 'unet_late_multi':MMUNET_LATE, + 'unet_early_multi':MMUNET_EARLY, + 'munet_ARTfusion': MUNET_ART_FUSION, + 'munet_restormer_fusion': MUNET_RESTORMER_FUSION, + + + 'minet_multi': MINET, + 'sanet_multi': SANET, + 'mtrans_multi': MTRANS, + 'swin_fusion': SWINMULTI, + 'munet_swinfuse': MUNET_SWINFUSE, + 'munet_concat': MUNET_CONCAT, + 'munet_concat_decomp': MUNET_CONCAT_DECOMP, + 'munet_sum': MUNET_SUM, + 'mmunet_art': MMUNET_ART, + 'mmunet_restormer': MMUNET_RESTORMER, + 'mmunet_restormer_small': MMUNET_RESTORMER_SMALL, + 'mmunet_restormer_v2': MMUNET_RESTORMER_V2, + + 'art_multi_input': ART_MULTI_INPUT, + + 'munet_ARTfusion_SeqConcat': MUNET_ART_FUSION_SEQ, + + 'kspace_ConvNet': KSPACE_CONVNET, + 'kspace_munet': KSPACE_MUNET, + 'kspace_munet_concat': KSPACE_MUNET_CONCAT, + 'kspace_munet_AttnFusion': KSPACE_MUNET_ATTNFUSE, + + 'dudo_munet': DUDO_MUNET, + 'dudo_munet_ARTfusion': DUDO_MUNET_ARTFUSION, + 'dudo_munet_concat': DUDO_MUNET_CONCAT, +} + + +def build_model_from_name(args): + assert args.model_name in model_factory.keys(), 'unknown model name' + + return model_factory[args.model_name](args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c130178e7cdabaaef8686da0cec341265f3c43d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn.py @@ -0,0 +1,246 @@ +""" +把our method中的单模态CNN_Transformer中的transformer换成几个conv_block用来下采样提取特征, 看是否能够达到和Unet相似的效果。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Reconstructor(nn.Module): + def __init__(self): + super(CNN_Reconstructor, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + nlatent = self.encoder.output_dim + self.decoder = Decoder(n_upsample=4, n_res=1, dim=nlatent, output_dim=1, pad_type='zero') + + + def forward(self, m): + #4,1,240,240 + feature_output = self.encoder.model(m) # [4, 256, 60, 60] + # print("output of the encoder:", feature_output.shape) + + output = self.decoder(feature_output) + + return output + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + 
nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + + def forward(self, input): + return self.model(input) + + + 
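+
+# Illustrative usage sketch (not part of the original file; shapes assume the 240x240
+# single-channel inputs used elsewhere in this repo):
+#     model = CNN_Reconstructor()                  # 4x down-sampling encoder + 4x up-sampling decoder
+#     out = model(torch.randn(4, 1, 240, 240))     # expected output: (4, 1, 240, 240), tanh-bounded
+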
+################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + + +def build_model(args): + return CNN_Reconstructor() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn_late_fusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..58dacad6910b6e896d00a0d36c9ea60e8d77217f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn_late_fusion.py @@ -0,0 +1,196 @@ +""" +在lequan的代码上,去掉transformer-based feature fusion部分。 +直接用CNN提取特征之后concat作为multi-modal representation. 用普通的decoder进行图像重建。 +把T1作为guidance modality, T2作为target modality. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. 
+ out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers1 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + self.down_sample_layers2 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers1.append(ConvBlock(ch, ch * 2, drop_prob)) + self.down_sample_layers2.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + + # self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # ch = chans + # for _ in range(num_pool_layers - 1): + # self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + # ch *= 2 + # self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv = nn.Sequential(nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), nn.Tanh()) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = torch.cat([image, aux_image], 1) + # for layer in self.down_sample_layers: + # output = layer(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + + # apply down-sampling layers + output = image + for layer in self.down_sample_layers1: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + output_m1 = output + + output = aux_image + for layer in self.down_sample_layers2: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + output_m2 = output + + # print("two modality outputs:", output_m1.shape, output.shape) + output = torch.cat([output_m1, output_m2], 1) + + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv in self.up_transpose_conv: + output = transpose_conv(output) + + output = self.up_conv(output) + # print("model output:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..261d4d214e35c3a9be5d28bc426ce7280124723c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn_transformer.py @@ -0,0 +1,289 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 
送到transformer网络中进行MRI重建。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=2, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=2, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 60 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 4 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
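+        # L2-normalise the patch tokens along the embedding dimension before prepending the
+        # cls token and adding the learned positional embedding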
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + feature_output = self.upsampling1(feature_output) # [4,512,30,30] + feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn_transformer_new.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn_transformer_new.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b2dea1e7cc8d456a684b6c839f052383999e84 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/cnn_transformer_new.py @@ -0,0 +1,290 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 送到transformer网络中进行MRI重建。 +2023/05/23: 之前是从尺寸为60*60的特征图上切patch, 现在可以尝试直接用conv_blocks提取15*15的特征图,然后按照patch_size = 1切成225个patches。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=4, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 1 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + # feature_output = self.upsampling1(feature_output) # [4,512,30,30] + # feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
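+        # with n_downsample=4 the encoder output is [B, 1024, 15, 15], so patch_size=1 turns each
+        # spatial position into one token (15*15 = 225); the "[4, 256, 60, 60]" comment above is
+        # carried over from the 60x60 variant in cnn_transformer.py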
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/humus_net.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/humus_net.py new file mode 100644 index 0000000000000000000000000000000000000000..d23701634f849b414866d71b3c23c6d07f3d6437 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/humus_net.py @@ -0,0 +1,1070 @@ +""" +HUMUS-Block +Hybrid Transformer-convolutional Multi-scale denoiser. + +We use parts of the code from +SwinIR (https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py) +Swin-Unet (https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.py) + +HUMUS-Net中是对kspace data操作, 多个unrolling cascade, 每个cascade中都利用HUMUS-block对图像进行denoising. +因为我们的方法输入是Under-sampled images, 所以不需要进行unrolling, 直接使用一个HUMUS-block进行MRI重建。 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # print("x shape:", x.shape) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. 
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + 
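+        # Illustrative example: window_partition and window_reverse (defined above) are
+        # exact inverses when H and W are divisible by window_size, e.g.
+        #     t = torch.randn(2, 8, 8, 4)                         # (B, H, W, C)
+        #     w = window_partition(t, 4)                          # (B*nW, 4, 4, C) = (8, 4, 4, 4)
+        #     assert torch.equal(window_reverse(w, 4, 8, 8), t)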
# reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + +class PatchExpandSkip(nn.Module): + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) + self.channel_mix = nn.Linear(dim, dim // 2, bias=False) + self.norm = norm_layer(dim // 2) + + def forward(self, x, skip): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x = torch.cat([x, skip], dim=-1) + x = self.channel_mix(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
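+    Note:
+        The blocks built below alternate regular and shifted window attention
+        (shift_size = 0 for even block indices, window_size // 2 otherwise);
+        e.g. with window_size=5, odd-indexed blocks use a cyclic shift of 2.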
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, torch.tensor(x_size)) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ block_type: + D: downsampling block, + U: upsampling block + B: bottleneck block + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', block_type='B'): + super(RSTB, self).__init__() + self.dim = dim + conv_dim = dim // (patch_size ** 2) + divide_out_ch = 1 + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(conv_dim, conv_dim // divide_out_ch, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(conv_dim, conv_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // divide_out_ch, 3, 1, 1)) + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.block_type = block_type + if block_type == 'B': + self.reshape = torch.nn.Identity() + elif block_type == 'D': + self.reshape = PatchMerging(input_resolution, dim) + elif block_type == 'U': + self.reshape = PatchExpandSkip([res // 2 for res in input_resolution], dim * 2) + else: + raise ValueError('Unknown RSTB block type.') + + def forward(self, x, x_size, skip=None): + if self.block_type == 'U': + assert skip is not None, "Skip connection is required for patch expand" + x = self.reshape(x, skip) + + out = self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) + block_out = self.patch_embed(out) + x + + if self.block_type == 'D': + block_out = (self.reshape(block_out), block_out) # return skip connection + + return block_out + +class PatchEmbed(nn.Module): + r""" Fixed Image to Patch Embedding. Flattens image along spatial dimensions + without learned projection mapping. Only supports 1x1 patches. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + embed_dim (int): Number of output channels. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + +class PatchEmbedLearned(nn.Module): + r""" Image to Patch Embedding with arbitrary patch size and learned linear projection. 
+ Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_chans * patch_size[0] * patch_size[1], embed_dim) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.patch_to_vec(x) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + def patch_to_vec(self, x): + """ + Args: + x: (B, C, H, W) + Returns: + patches: (B, num_patches, patch_size * patch_size * C) + """ + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(B, H // self.patch_size[0], self.patch_size[0], W // self.patch_size[1], self.patch_size[1], C) + patches = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.patch_size[0], self.patch_size[1], C) + patches = patches.contiguous().view(B, self.num_patches, self.patch_size[0] * self.patch_size[1] * C) + return patches + + +class PatchUnEmbed(nn.Module): + r""" Fixed Image to Patch Unembedding via reshaping. Only supports patch size 1x1. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + +class PatchUnEmbedLearned(nn.Module): + r""" Patch Embedding to Image with learned linear projection. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. 
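+    Note:
+        Shape-only sketch (assuming patch_size=2, in_chans=66, img_size=(240, 240)):
+        an input of shape (B, 14400, 66) is projected to (B, 14400, 264) by proj_up
+        and reassembled by vec_to_patch into an image-shaped tensor (B, 66, 240, 240).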
+ """ + + def __init__(self, img_size, patch_size, in_chans=None, out_chans=None, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.out_chans = out_chans + + self.proj_up = nn.Linear(in_chans, in_chans * patch_size[0] * patch_size[1]) + + def forward(self, x, x_size=None): + B, HW, C = x.shape + x = self.proj_up(x) + x = self.vec_to_patch(x) + return x + + def vec_to_patch(self, x): + B, HW, C = x.shape + x = x.view(B, self.patches_resolution[0], self.patches_resolution[1], C).contiguous() + x = x.view(B, self.patches_resolution[0], self.patch_size[0], self.patches_resolution[1], self.patch_size[1], C // (self.patch_size[0] * self.patch_size[1])).contiguous() + x = x.view(B, self.img_size[0], self.img_size[1], -1).contiguous().permute(0, 3, 1, 2) + return x + + +class HUMUSNet(nn.Module): + r""" HUMUS-Block + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depths (tuple(int)): Depth of each Swin Transformer layer in encoder and decoder paths. + num_heads (tuple(int)): Number of attention heads in different layers of encoder and decoder. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + use_checkpoint (bool): Whether to use checkpointing to save memory. + img_range: Image range. 1. or 255. + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + conv_downsample_first: use convolutional downsampling before MUST to reduce compute load on Transformers + """ + + def __init__(self, args, + img_size=[240, 240], + in_chans=1, + patch_size=1, + embed_dim=66, + depths=[2, 2, 2], + num_heads=[3, 6, 12], + window_size=5, + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + img_range=1., + resi_connection='1conv', + bottleneck_depth=2, + bottleneck_heads=24, + conv_downsample_first=True, + out_chans=1, + no_residual_learning=False, + **kwargs): + super(HUMUSNet, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans if out_chans is None else out_chans + + self.center_slice_out = (out_chans == 1) + + self.img_range = img_range + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + self.conv_downsample_first = conv_downsample_first + self.no_residual_learning = no_residual_learning + + ##################################################################################################### + ################################### 1, input block ################################### + input_conv_dim = embed_dim + self.conv_first = nn.Conv2d(num_in_ch, input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** self.num_layers) + self.mlp_ratio = mlp_ratio + + # Downsample for low-res feature extraction + if self.conv_downsample_first: + img_size = [im //2 for im in img_size] + self.conv_down_block = ConvBlock(input_conv_dim // 2, input_conv_dim, 0.0) + self.conv_down = DownsampConvBlock(input_conv_dim, input_conv_dim) + + # split image into non-overlapping patches + if patch_size > 1: + self.patch_embed = PatchEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + if patch_size > 1: + self.patch_unembed = PatchUnEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, out_chans=embed_dim, norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build MUST + # encoder + self.layers_down = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** i_layer) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + 
patches_resolution[1] // dim_scaler), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='D', + ) + self.layers_down.append(layer) + + # bottleneck + dim_scaler = (2 ** self.num_layers) + self.layer_bottleneck = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=bottleneck_depth, + num_heads=bottleneck_heads, + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='B', + ) + + # decoder + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** (self.num_layers - i_layer - 1)) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='U', + ) + self.layers_up.append(layer) + + self.norm_down = norm_layer(self.num_features) + self.norm_up = norm_layer(self.embed_dim) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(input_conv_dim, input_conv_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(input_conv_dim, input_conv_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim, 3, 1, 1)) + + # Upsample if needed + if self.conv_downsample_first: + self.conv_up_block = ConvBlock(input_conv_dim, input_conv_dim // 2, 0.0) + self.conv_up = TransposeConvBlock(input_conv_dim, input_conv_dim // 2) + + ##################################################################################################### + ################################ 3, output block ################################ + self.conv_last = nn.Conv2d(input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, num_out_ch, 3, 1, 1) + self.output_norm = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + 
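+    # Illustrative note on the stochastic depth schedule built in __init__ above:
+    # with the defaults depths=[2, 2, 2] and drop_path_rate=0.1,
+    #     dpr = torch.linspace(0, 0.1, steps=6)  ->  [0.00, 0.02, 0.04, 0.06, 0.08, 0.10]
+    # so deeper encoder blocks get larger DropPath rates, and each decoder stage
+    # reuses the slice of dpr that matches its mirrored encoder stage.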
@torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + + # divisible by window size + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + + # divisible by total downsampling ratio + # this could be done more efficiently by combining the two + total_downsamp = int(2 ** (self.num_layers - 1)) + pad_h = (total_downsamp - h % total_downsamp) % total_downsamp + pad_w = (total_downsamp - w % total_downsamp) % total_downsamp + x = F.pad(x, (0, pad_w, 0, pad_h)) + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # encode + skip_cons = [] + for i, layer in enumerate(self.layers_down): + x, skip = layer(x, x_size=layer.input_resolution) + skip_cons.append(skip) + x = self.norm_down(x) # B L C + + # bottleneck + x = self.layer_bottleneck(x, self.layer_bottleneck.input_resolution) + + # decode + for i, layer in enumerate(self.layers_up): + x = layer(x, x_size=layer.input_resolution, skip=skip_cons[-i-1]) + x = self.norm_up(x) + x = self.patch_unembed(x, x_size) + return x + + def forward(self, x): + # print("input shape:", x.shape) + # print("input range:", x.max(), x.min()) + C, H, W = x.shape[1:] + center_slice = (C - 1) // 2 + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.conv_downsample_first: + x_first = self.conv_first(x) + x_down = self.conv_down(self.conv_down_block(x_first)) + res = self.conv_after_body(self.forward_features(x_down)) + res = self.conv_up(res) + res = torch.cat([res, x_first], dim=1) + res = self.conv_up_block(res) + + res = self.conv_last(res) + + if self.no_residual_learning: + x = res + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + res + else: + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + + if self.no_residual_learning: + x = self.conv_last(res) + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + self.conv_last(res) + + # print("output range before scaling:", x.max(), x.min()) + # print("self.mean:", self.mean) + + if self.center_slice_out: + x = x / self.img_range + self.mean[:, center_slice, ...].unsqueeze(1) + else: + x = x / self.img_range + self.mean + + x = self.output_norm(x) + + return x[:, :, :H, :W] + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + +class DownsampConvBlock(nn.Module): + """ + A Downsampling Convolutional Block that consists of one strided convolution + layer followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.Conv2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H/2, W/2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return HUMUSNet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/kspace_mUnet_AttnFusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/kspace_mUnet_AttnFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..02588c24559c0604ddbbd03a14b3b32e6cff2c8d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/kspace_mUnet_AttnFusion.py @@ -0,0 +1,300 @@ +""" +Xiaohan Xing, 2023/11/22 +两个模态分别用CNN提取多个层级的特征, 每个层级都把特征沿着宽度拼接起来,得到(h, 2w, c)的特征。 +然后学习得到(h, 2w)的attention map, 和原始的特征相乘送到后面的层。 +""" + +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.attn_layers = nn.ModuleList() + + for l in range(self.num_pool_layers): + + self.attn_layers.append(nn.Sequential(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob), + nn.Conv2d(chans*(2**l), 2, kernel_size=1, stride=1), + nn.Sigmoid())) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, attn_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.attn_layers): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat之后求attention, 然后进行modulation. 
+ attn = attn_layer(torch.cat((output1, output2), 1)) + # print("spatial attention:", attn[:, 0, :, :].unsqueeze(1).shape) + # print("output1:", output1.shape) + output1 = output1 * (1 + attn[:, 0, :, :].unsqueeze(1)) + output2 = output2 * (1 + attn[:, 1, :, :].unsqueeze(1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/kspace_mUnet_concat.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/kspace_mUnet_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..d16bfdc83305d3e1f490e5ca4b445c6d730557ce --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/kspace_mUnet_concat.py @@ -0,0 +1,351 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Predict the low-frequency mask ############## +class Mask_Predictor(nn.Module): + def __init__(self, in_chans, out_chans, chans, drop_prob): + super(Mask_Predictor, self).__init__() + self.conv1 = ConvBlock(in_chans, chans, drop_prob) + self.conv2 = ConvBlock(chans, chans*2, drop_prob) + self.conv3 = ConvBlock(chans*2, chans*4, drop_prob) + + # self.conv4 = ConvBlock(chans*4, 1, drop_prob) + + self.FC = nn.Linear(chans*4, out_chans) + self.act = nn.Sigmoid() + + + def forward(self, x): + ### 三层卷积和down-sampling. + # print("[Mask predictor] input data range:", x.max(), x.min()) + output = F.avg_pool2d(self.conv1(x), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer1_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv2(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer2_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv3(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer3_feature range:", output.max(), output.min()) + + feature = F.adaptive_avg_pool2d(F.relu(output), (1, 1)).squeeze(-1).squeeze(-1) + # print("mask_prediction feature:", output[:, :5]) + # print("[Mask predictor] bottleneck feature range:", feature.max(), feature.min()) + + output = self.FC(feature) + # print("output before act:", output) + output = self.act(output) + # print("mask output:", output) + + return output + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), 
+ ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.mask_predictor1 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + self.mask_predictor2 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
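+        Note:
+            This k-space variant actually returns the 6-tuple
+            (output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight),
+            where the mask coordinates and DC weights come from the two Mask_Predictor heads.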
+ """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + ### 预测做low-freq DC的mask区域. + ### 输出三个维度, 前两维是坐标,最后一维是mask内部的权重 + t1_mask = self.mask_predictor1(output1) + t2_mask = self.mask_predictor2(output2) + + # print("t1_mask and t2_mask:", t1_mask.shape, t2_mask.shape) + t1_mask_coords, t2_mask_coords = t1_mask[:, :2], t2_mask[:, :2] + t1_DC_weight, t2_DC_weight = t1_mask[:, -1], t2_mask[:, -1] + + + # ### 预测做low-freq DC的DC weight + # t1_DC_weight = self.mask_predictor1(output1) + # t2_DC_weight = self.mask_predictor2(output2) + + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight ## + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mARTUnet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9607d41ca14562a94e542a1804af5aeaf15bc123 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mARTUnet.py @@ -0,0 +1,563 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +2023/07/20, Xiaohan Xing +将两个模态的图像在输入端concat, 得到num_channels = 2的input, 然后送到ART模型进行T2 modality的重建。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # print("feature after layer_norm:", x.max(), x.min()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", 
attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=2, + out_channels=1, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img, aux_img): + stack = [] + bs, _, H, W = inp_img.shape + fuse_img = torch.cat((inp_img, aux_img), 1) + inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + 
self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=2, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_ARTfusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..610b1e1e38d74f695a3a19f22a6a80f3152bf713 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_ARTfusion.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import CS_TransformerBlock as TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class 
mUnet_ARTfusion(nn.Module):
+    """
+    The overall framework is a multi-modal U-Net. The two modalities are encoded separately at every
+    level, and the features of each level are fused with ART blocks.
+    """
+
+    def __init__(
+            self, args,
+            input_dim: int = 2,
+            output_dim: int = 1,
+            chans: int = 32,
+            num_pool_layers: int = 4,
+            drop_prob: float = 0.0,
+            num_blocks=[2, 4, 4, 6],
+            heads=[1, 2, 4, 8],
+            window_size=[10, 10, 10, 10],
+            interval=[24, 12, 6, 3],
+            mlp_ratio=4.):
+
+        super().__init__()
+
+        self.in_chans = input_dim
+        self.out_chans = output_dim
+        self.chans = chans
+        self.num_pool_layers = num_pool_layers
+        self.drop_prob = drop_prob
+
+        self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob)
+        self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob)
+        ch = self.encoder1.ch
+
+        self.fuse_layers = nn.ModuleList()
+        for l in range(self.num_pool_layers):
+            self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l],
+                window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1,
+                mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0.,
+                attn_drop=0., drop_path=0.) for i in range(num_blocks[l])]))
+
+        # print(self.fuse_layers)
+
+        self.conv1 = ConvBlock(ch, ch * 2, drop_prob)
+        self.conv2 = ConvBlock(ch, ch * 2, drop_prob)
+
+        self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob)
+        self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob)
+
+    def forward(self, image1: torch.Tensor, image2: torch.Tensor,
+                image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            image1, image2: Input 4D tensors, one per modality.
+            image1_krecon, image2_krecon: Auxiliary reconstructions that are concatenated with
+                image1/image2 along the channel dimension.
+
+        Returns:
+            Two output tensors of shape `(N, out_chans, H, W)`, one per modality.
+        """
+        output1 = torch.cat((image1, image1_krecon), 1)
+        output2 = torch.cat((image2, image2_krecon), 1)
+
+        # output1, output2 = image1, image2
+
+        stack1, stack2 = [], []
+        t1_features, t2_features = [], []
+        relation_stack1, relation_stack2 = [], []
+
+        l = 0
+        bs, _, H, W = image1.shape
+
+        ### Extract features of each modality with its own CNN layer, then fuse the two modalities with
+        ### the transformer fusion layer; the fused features are passed on to the subsequent layers.
+        for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers):
+
+            # print("number of blocks in fuse layer:", len(fuse_layer))
+
+            ### The encoder features from before multi-modal fusion are passed to the decoder through skip connections.
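+            # Shape bookkeeping for level l (derived from the code below): each ConvBlock output has
+            # chans * 2**l channels at resolution (H / 2**l, W / 2**l). The two modality features are
+            # concatenated along the channel dimension, flattened to (B, H*W, C) tokens, passed through
+            # this level's ART TransformerBlocks, split back into the two modalities, and average-pooled
+            # by a factor of 2 before entering the next level.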
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_ARTfusion_SeqConcat.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_ARTfusion_SeqConcat.py new file mode 100644 index 0000000000000000000000000000000000000000..55280925415f1db47cf746b59fb9710d47c9bff2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_ARTfusion_SeqConcat.py @@ -0,0 +1,374 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any + +import matplotlib.pyplot as plt +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet_new import ConcatTransformerBlock, TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + 
+class mUnet_ARTfusion_SeqConcat(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4. + ): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.vis_feature_maps = False + self.vis_attention_maps = False + # self.vis_feature_maps = args.MODEL.VIS_FEAT_MAPS + # self.vis_attention_maps = args.MODEL.VIS_ATTN_MAPS + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + if l < 2: + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * (2 ** (l + 1))), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + else: + self.fuse_layers.append(nn.ModuleList([ConcatTransformerBlock(dim=int(chans * (2 ** l)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + feature_maps = { + 'modal1': {'before': [], 'after': []}, + 'modal2': {'before': [], 'after': []} + } + + attention_maps = [] + """ + attention_maps = [ + unet layer 1 attention maps (x2): [map from TransBlock1, TransBlock2, ...], + unet layer 2 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 3 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 4 attention maps (x6): [map from TransBlock1, TransBlock2, ...], + ] + """ + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
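+            # For the first two levels (l < 2) the two modality features are concatenated along the
+            # channel dimension and fused by a standard TransformerBlock; for the deeper levels (l >= 2)
+            # they are kept as separate token sequences and fused by ConcatTransformerBlock, which
+            # returns the two modality streams separately (see the branches on `l < 2` below).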
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + if self.vis_feature_maps: + feature_maps['modal1']['before'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['before'].append(output2.detach().cpu().numpy()) + + ### 两个模态的特征concat + # output = torch.cat((output1, output2), 1) + + + + unet_layer_attn_maps = [] + """ + unet_layer_attn_maps: [ + attn_map from TransformerBlock1: (B, nHGroups, nWGroups, winSize, winSize), + attn_map from TransformerBlock2: (B, nHGroups, nWGroups, winSize, winSize), + ... + ] + """ + if l < 2: + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + else: + output1 = rearrange(output1, "b c h w -> b (h w) c").contiguous() + output2 = rearrange(output2, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + if l < 2: + if self.vis_attention_maps: + output, attn_map = layer(output, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output = layer(output, [H // (2**l), W // (2**l)]) + + else: + if self.vis_attention_maps: + output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + attention_maps.append(unet_layer_attn_maps) + + if l < 2: + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + else: + output1 = rearrange(output1, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output2 = rearrange(output2, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + # if self.vis_attention_maps: + # output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) # attn_map: (B, nHGroups, nWGroups, winSize, winSize) + # unet_layer_attn_maps.append(attn_map) + # else: + # output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + # attention_maps.append(unet_layer_attn_maps) + + + + if self.vis_feature_maps: + feature_maps['modal1']['after'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['after'].append(output2.detach().cpu().numpy()) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + if self.vis_feature_maps: + return output1, output2, feature_maps#, relation_stack1, relation_stack2 #, t1_features, t2_features + elif self.vis_attention_maps: + return output1, output2, attention_maps + else: + return output1, output2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion_SeqConcat(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_Restormer_fusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_Restormer_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8ab29b641fa0231ad0d163224a4a90b44c32a9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_Restormer_fusion.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .restormer import TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8]): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.Sequential(*[TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[0], ffn_expansion_factor=2.66, \ + bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + ### 将encoder中multi-modal fusion之后的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = fuse_layer(output) + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2a209758203b198502154b75496333ef91717a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mUnet_transformer.py @@ -0,0 +1,387 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+ +Xiaohan Xing, 2023/06/08 +分别用CNN提取两个模态的特征,然后送到transformer进行特征交互, 将经过transformer提取的特征和之前CNN提取的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Decoder_wo_skip(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder_wo_skip, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + output = transpose_conv(output) + output = conv(output) + + return output + + + + + +class mUnet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch*2 + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1, output2 = image1, image2 + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack1, output1 = self.encoder1(output1) ### output size: [4, 256, 15, 15] + stack2, output2 = self.encoder2(output2) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output1) # [4, 225, 256], 225 is the number of patches. + patch_embed_input1 = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + n_embed = self.to_patch_embedding(output2) # [4, 225, 256], 225 is the number of patches. 
+ patch_embed_input2 = F.normalize(n_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input1, patch_embed_input2), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1]/2)), int(np.sqrt(feature_output.shape[1]/2)) + patch_num = feature_output.shape[1] + feature_output1 = feature_output[:, :patch_num//2, :] + feature_output2 = feature_output[:, patch_num//2:, :] + feature_output1 = feature_output1.contiguous().view(b, feature_output1.shape[-1], h, w) # [4,c,15,15] + feature_output2 = feature_output2.contiguous().view(b, feature_output2.shape[-1], h, w) + + output = torch.cat((output1, output2), 1) + torch.cat((feature_output1, feature_output2), 1) + # print("CNN feature range:", output1.max(), output1.min()) + # print("Transformer feature range:", feature_output1.max(), feature_output1.min()) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_Transformer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee0c4784d04638df4ca259613777dde34980a2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/04 +Concatenate the input images from different modalities and input it into the UNet. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_ART.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..1a09f2900a8924cdc1a7c373270e6ff8aadcfa45 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_ART.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in 
range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_ART_v2.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_ART_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe6bc0d9192a126f39932b40443badb5b60068 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_ART_v2.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. 
+ num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_early_fusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_early_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d062f057a2080a471005125a00ee27453a4b0706 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_early_fusion.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +encapsulate the encoder and decoder for the Unet model. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + stack, output = self.encoder(output) + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + output = self.decoder(output, stack) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_late_fusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b48e7d4445afd6a5bc1eaa2b065c37431585e94 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_late_fusion.py @@ -0,0 +1,229 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +use Unet for the reconstruction of each modality, concat the intermediate features of both modalities. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("decoder output:", output.shape) + + return output + + + + +class mmUnet_late(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack1, output1 = self.encoder1(image) + stack2, output2 = self.encoder2(aux_image) + ### fuse two modalities to get multi-modal representation. + output = torch.cat([output1, output2], 1) + output = self.conv(output) ### from [4, 512, 15, 15] to [4, 512, 15, 15] + + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet_late(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_restormer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..88c7920d65c58ab82a00309a1ae501a8d5b33804 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_restormer.py @@ -0,0 +1,237 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + 
nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_restormer_3blocks.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_restormer_3blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e3afe68612727d3195872a121ec8040fb0f66ef2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_restormer_3blocks.py @@ -0,0 +1,235 @@ +""" +2023/07/18, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +之前 3 blocks的模型深层特征存在严重的gradient vanishing问题。现在把模型的encoder和decoder都改成只有3个blocks, 看是否还存在gradient vanishing问题。 +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 3, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2] + heads=[1, 2, 4] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_restormer_v2.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b51d3a218057206863230973d557173e891d3cd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mmunet_restormer_v2.py @@ -0,0 +1,220 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +这个版本中是将Unet encoder提取的特征直接连接到decoder, 经过restormer refine的特征只是传递到下一层的encoder block. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + output = feature + + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mtrans_mca.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mtrans_mca.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4c89df4af71e986fa7c24d522e6732371f3128 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mtrans_mca.py @@ -0,0 +1,119 @@ +import torch +from torch import nn + +from .mtrans_transformer import Mlp + +# implement by YunluYan + + +class LinearProject(nn.Module): + def __init__(self, dim_in, dim_out, fn): + super(LinearProject, self).__init__() + need_projection = dim_in != dim_out + self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity() + self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity() + self.fn = fn + + def forward(self, x, *args, **kwargs): + x = self.project_in(x) + x = self.fn(x, *args, **kwargs) + x = self.project_out(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): + super(MultiHeadCrossAttention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim*2, bias=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x, complement): + + # x [B, HW, C] + B_x, N_x, C_x = x.shape + + x_copy = x + + complement = torch.cat([x, complement], 1) + + B_c, N_c, C_c = complement.shape + + # q [B, heads, HW, C//num_heads] + q = self.to_q(x).reshape(B_x, N_x, self.num_heads, C_x//self.num_heads).permute(0, 2, 1, 3) + kv = self.to_kv(complement).reshape(B_c, N_c, 2, self.num_heads, C_c//self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_x, N_x, C_x) + + x = x + x_copy + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio = 1., attn_drop=0., proj_drop=0.,drop_path = 0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super(CrossTransformerEncoderLayer, self).__init__() + self.x_norm1 = norm_layer(dim) + self.c_norm1 = norm_layer(dim) + + self.attn = MultiHeadCrossAttention(dim, num_heads, attn_drop, proj_drop) + + self.x_norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) + + self.drop1 = nn.Dropout(drop_path) + self.drop2 = nn.Dropout(drop_path) + + def forward(self, x, complement): + x = self.x_norm1(x) + complement = self.c_norm1(complement) + + x = x + self.drop1(self.attn(x, complement)) + x = x + self.drop2(self.mlp(self.x_norm2(x))) + return x + + + +class CrossTransformer(nn.Module): + def __init__(self, x_dim, c_dim, depth, num_heads, mlp_ratio =1., attn_drop=0., proj_drop=0., drop_path =0.): + super(CrossTransformer, self).__init__() + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LinearProject(x_dim, c_dim, CrossTransformerEncoderLayer(c_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)), + LinearProject(c_dim, 
x_dim, CrossTransformerEncoderLayer(x_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)) + ])) + + def forward(self, x, complement): + for x_attn_complemnt, complement_attn_x in self.layers: + x = x_attn_complemnt(x, complement=complement) + x + complement = complement_attn_x(complement, complement=x) + complement + return x, complement + + + + + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mtrans_net.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mtrans_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e80dadb45725603c7ff1da664a45533575db25 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mtrans_net.py @@ -0,0 +1,145 @@ +""" +This script implement the paper "Multi-Modal Transformer for Accelerated MR Imaging (TMI 2022)" +""" + + +import torch + +from torch import nn +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler +from .mtrans_mca import CrossTransformer + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(INPUT_DIM, HEAD_HIDDEN_DIM): + return ReconstructionHead(INPUT_DIM, HEAD_HIDDEN_DIM) + + +class CrossCMMT(nn.Module): + + def __init__(self, args): + super(CrossCMMT, self).__init__() + + INPUT_SIZE = 240 + INPUT_DIM = 1 # the channel of input + OUTPUT_DIM = 1 # the channel of output + HEAD_HIDDEN_DIM = 16 # the hidden dim of Head + TRANSFORMER_DEPTH = 4 # the depth of the transformer + TRANSFORMER_NUM_HEADS = 4 # the head's num of multi head attention + TRANSFORMER_MLP_RATIO = 3 # the MLP RATIO Of transformer + TRANSFORMER_EMBED_DIM = 256 # the EMBED DIM of transformer + P1 = 8 + P2 = 16 + CTDEPTH = 4 + + + self.head = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + self.head2 = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + + x_patch_dim = HEAD_HIDDEN_DIM * P1 ** 2 + x_num_patches = (INPUT_SIZE // P1) ** 2 + + complement_patch_dim = HEAD_HIDDEN_DIM * P2 ** 2 + complement_num_patches = (INPUT_SIZE // P2) ** 2 + + self.x_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P1, + p2=P1), + ) + + self.complement_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P2, + p2=P2), + ) + + self.x_pos_embedding = nn.Parameter(torch.randn(1, x_num_patches, x_patch_dim)) + self.complement_pos_embedding = nn.Parameter(torch.randn(1, complement_num_patches, complement_patch_dim)) + + ### + self.cross_transformer = CrossTransformer(x_patch_dim, complement_patch_dim, CTDEPTH, + TRANSFORMER_NUM_HEADS, TRANSFORMER_MLP_RATIO) + + self.p1 = P1 + self.p2 = P2 + + self.tail1 = 
nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + self.tail2 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x, complement): + x = self.head(x) + complement = self.head2(complement) + + x = x + complement + + b, _, h, w = x.shape + + _, _, c_h, c_w = complement.shape + + ### 得到每个模态的patch embedding (加上了position encoding) + x = self.x_patch_embbeding(x) + x += self.x_pos_embedding + + complement = self.complement_patch_embbeding(complement) + complement += self.complement_pos_embedding + + ### 将两个模态的patch embeddings送到cross transformer中提取特征。 + x, complement = self.cross_transformer(x, complement) + + c = int(x.shape[2] / (self.p1 * self.p1)) + H = int(h / self.p1) + W = int(w / self.p1) + + x = x.reshape(b, H, W, self.p1, self.p1, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w, ) + x = self.tail1(x) + + complement = complement.reshape(b, int(c_h/self.p2), int(c_w/self.p2), self.p2, self.p2, int(complement.shape[2]/self.p2/self.p2)) + complement = complement.permute(0, 5, 1, 3, 2, 4) + complement = complement.reshape(b, -1, c_h, c_w) + + complement = self.tail2(complement) + + return x, complement + + +def build_model(args): + return CrossCMMT(args) + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mtrans_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mtrans_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..61314a6a81247b0740a1e98d078242b26ce73f90 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/mtrans_transformer.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, 
qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on 
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
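As a quick illustration of what drop_path above does at training time (a toy sketch, not repo code): each sample in the batch keeps its residual branch with probability 1 - drop_prob, and kept branches are rescaled by 1 / keep_prob so the expected output is unchanged.

import torch

torch.manual_seed(0)
x = torch.ones(4, 3, 8, 8)          # a batch of four residual-branch outputs
drop_prob = 0.5
keep_prob = 1 - drop_prob
# one Bernoulli(keep_prob) draw per sample, broadcast over all remaining dims
mask = (keep_prob + torch.rand(4, 1, 1, 1)).floor_()
out = x.div(keep_prob) * mask
print(mask.view(-1))                 # per-sample 0/1 gates (values depend on the RNG state)
print(out.mean(dim=(1, 2, 3)))       # dropped samples are all-zero, kept ones are scaled by 1/keep_prob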
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(args, embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=args.MODEL.TRANSFORMER_DEPTH, + num_heads=args.MODEL.TRANSFORMER_NUM_HEADS, + mlp_ratio=args.MODEL.TRANSFORMER_MLP_RATIO) + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_concat_decomp.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_concat_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..07c6f64804dd586927814801b167163f0b65f58f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_concat_decomp.py @@ -0,0 +1,295 @@ +""" +Xiaohan Xing, 2023/06/25 +两个模态分别用CNN提取多个层级的特征, 每个层级的特征都经过concat之后得到fused feature, +然后每个模态都通过一层conv变换得到2C*h*w的特征, 其中C个channels作为common feature, 另C个channels作为specific features. +利用CDDFuse论文中的decomposition loss约束特征解耦。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.transform_layers1 = nn.ModuleList() + self.transform_layers2 = nn.ModuleList() + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.transform_layers1.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.transform_layers2.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + stack1_common, stack1_specific, stack2_common, stack2_specific = [], [], [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_transform_layer, net2_transform_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.transform_layers1, self.transform_layers2, \ + self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + # print("original feature:", output1.shape) + + ### 分别提取两个模态的common feature和specific feature + output1 = net1_transform_layer(output1) + # print("transformed feature:", output1.shape) + output1_common = output1[:, :(output1.shape[1]//2), :, :] + output1_specific = output1[:, (output1.shape[1]//2):, :, :] + output2 = net2_transform_layer(output2) + output2_common = output2[:, :(output2.shape[1]//2), :, :] + output2_specific = output2[:, (output2.shape[1]//2):, :, :] + + stack1_common.append(output1_common) + stack1_specific.append(output1_specific) + stack2_common.append(output2_common) + stack2_specific.append(output2_specific) + + ### 将一个模态的两组feature和另一模态的common feature合并, 然后经过conv变换得到和初始特征相同维度的fused feature送到下一层。 + output1 = net1_fuse_layer(torch.cat((output1_common, output1_specific, output2_common), 1)) + output2 = net2_fuse_layer(torch.cat((output2_common, output2_specific, output1_common), 1)) + # print("fused feature:", output1.shape) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, stack1_common, stack1_specific, stack2_common, stack2_specific + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
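The loop above implements the decomposition described in the module docstring: each modality's transform layer doubles the channels to 2C, the first half is treated as the modality-common part and the second half as the modality-specific part, and each modality's fuse convolution then receives its own two halves plus the other modality's common half (3C channels in total). A minimal shape-level sketch with made-up tensors and C = 32, matching the default chans:

import torch

feat1 = torch.randn(2, 64, 60, 60)   # transformed feature of modality 1 (2C channels)
feat2 = torch.randn(2, 64, 60, 60)   # transformed feature of modality 2

c = feat1.shape[1] // 2
common1, specific1 = feat1[:, :c], feat1[:, c:]
common2, specific2 = feat2[:, :c], feat2[:, c:]

# each modality keeps its own common + specific parts and borrows the other
# modality's common part before the per-level fuse convolution (3C channels in)
fuse_in1 = torch.cat((common1, specific1, common2), dim=1)
fuse_in2 = torch.cat((common2, specific2, common1), dim=1)
print(fuse_in1.shape)                # torch.Size([2, 96, 60, 60])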
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_multi_concat.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_multi_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90373e1134210809b98fdc64e3d51d76123855 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_multi_concat.py @@ -0,0 +1,306 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + 
padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + # output1, output2 = image1, image2 + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # relation1 = get_relation_matrix(output1) + # relation2 = get_relation_matrix(output2) + # relation_stack1.append(relation1) + # relation_stack2.append(relation2) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + # output1 = net1_layer(output1) + # output2 = net2_layer(output2) + + # ### 两个模态的特征concat + # fused_output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + # fused_output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # stack1.append(fused_output1) + # stack2.append(fused_output2) + + # ### downsampling features + # output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + # output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
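get_relation_matrix above pools a feature map to a 5x5 grid of anchor vectors and returns the 25x25 matrix of pairwise cosine similarities between them; in this file the calls to it are commented out inside forward. A hypothetical usage sketch follows (an assumption about intent, e.g. a relation-consistency term between the two modality branches; the tensors are toy values and the spatial size must be a multiple of 15):

import torch
import torch.nn.functional as F

feat_t1 = torch.randn(2, 32, 60, 60)
feat_t2 = torch.randn(2, 32, 60, 60)

r1 = get_relation_matrix(feat_t1)    # (2, 25, 25) cosine-similarity matrices
r2 = get_relation_matrix(feat_t2)
relation_loss = F.mse_loss(r1, r2)   # penalize differing relational structure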
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_multi_sum.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_multi_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4efb13b54ccbeaeb34fc51d3e72b0c50ee242 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_multi_sum.py @@ -0,0 +1,270 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # 
padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers): + + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + + ## 两个模态的特征求平均. + fused_output1 = (output1 + output2)/2.0 + fused_output2 = (output1 + output2)/2.0 + + output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features #, relation_stack1, relation_stack2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_multi_transfuse.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_multi_transfuse.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ce250a078ba0847c729c1ef1b76452f64d0f07 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_multi_transfuse.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +2023/06/23 +每个层级的特征用TransFuse layer融合之后, 将两个模态的original features和transfuse features concat, +然后在每个模态中经过conv层变换得到下一层的输入。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .TransFuse import TransFuse_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_TransFuse(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + ### conv layer to transform the features after Transfuse layer. + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + self.n_anchors = 15 + self.avgpool = nn.AdaptiveAvgPool2d((self.n_anchors, self.n_anchors)) + self.transfuse_layer1 = TransFuse_layer(n_embd=self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer2 = TransFuse_layer(n_embd=2*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer3 = TransFuse_layer(n_embd=4*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer4 = TransFuse_layer(n_embd=8*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + + self.transfuse_layers = nn.Sequential(self.transfuse_layer1, self.transfuse_layer2, self.transfuse_layer3, self.transfuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, transfuse_layer, net1_fuse_layer, net2_fuse_layer in zip(self.encoder1.down_sample_layers, \ + self.encoder2.down_sample_layers, self.transfuse_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### extract multi-level features from the encoder of each modality. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + ### Transformer-based multi-modal fusion layer + feature1 = self.avgpool(output1) + feature2 = self.avgpool(output2) + feature1, feature2 = transfuse_layer(feature1, feature2) + feature1 = F.interpolate(feature1, scale_factor=output1.shape[-1]//self.n_anchors, mode='bilinear') + feature2 = F.interpolate(feature2, scale_factor=output2.shape[-1]//self.n_anchors, mode='bilinear') + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
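mUnet_TransFuse.forward above fuses the two branches through a fixed 15x15 anchor grid: each feature map is adaptively average-pooled to the anchor resolution, passed through a TransFuse_layer, bilinearly interpolated back, and concatenated with the original features before a 4C -> C fuse convolution. A dependency-free sketch of that routing (a no-op stands in for the actual fusion layer; the shapes are illustrative):

import torch
from torch import nn
from torch.nn import functional as F

n_anchors = 15
avgpool = nn.AdaptiveAvgPool2d((n_anchors, n_anchors))
output1 = torch.randn(1, 32, 120, 120)       # level-0 features of the two modalities
output2 = torch.randn(1, 32, 120, 120)

feature1, feature2 = avgpool(output1), avgpool(output2)   # (1, 32, 15, 15) anchor grids
# transfuse_layer(feature1, feature2) would exchange information between the grids here
feature1 = F.interpolate(feature1, scale_factor=output1.shape[-1] // n_anchors, mode="bilinear")
feature2 = F.interpolate(feature2, scale_factor=output2.shape[-1] // n_anchors, mode="bilinear")

fused_in = torch.cat((output1, output2, feature1, feature2), dim=1)
print(fused_in.shape)    # torch.Size([1, 128, 120, 120]); matches ConvBlock(4 * chans, chans)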
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_TransFuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_swinfusion.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_swinfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6203c4331744f1da70f18e403360196a419b4cb5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/munet_swinfusion.py @@ -0,0 +1,268 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .SwinFuse_layer import SwinFusion_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_SwinFuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过SwinFusion的方式融合。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. 
+ drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.fuse_type = 'swinfuse' + self.img_size = 240 + # self.fuse_type = 'sum' + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # Swin transformer fusion modules + + self.fuse_layer1 = SwinFusion_layer(img_size=(self.img_size//2, self.img_size//2), patch_size=4, window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer2 = SwinFusion_layer(img_size=(self.img_size//4, self.img_size//4), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=2*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer3 = SwinFusion_layer(img_size=(self.img_size//8, self.img_size//8), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=4*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer4 = SwinFusion_layer(img_size=(self.img_size//16, self.img_size//16), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=8*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.swinfusion_layers = nn.Sequential(self.fuse_layer1, self.fuse_layer2, self.fuse_layer3, self.fuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, swinfuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.swinfusion_layers): + output1 = net1_layer(output1) + stack1.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # print("output of encoders:", output1.shape, output2.shape) + # print("CNN feature range:", output1.max(), output2.min()) + + ### Swin transformer-based multi-modal fusion. + feature1, feature2 = swinfuse_layer(output1, output2) + output1 = (output1 + feature1 + output2 + feature2)/4.0 + output2 = (output1 + feature1 + output2 + feature2)/4.0 + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_SwinFuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/original_MINet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/original_MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a115872320f677d8b34264eed95c22fff0df9c4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/original_MINet.py @@ -0,0 +1,335 @@ +from fastmri.models import common +import torch +import torch.nn as nn +import torch.nn.functional as F + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + 
res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,n_resgroups,n_resblocks,n_feats,conv=common.default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 2 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + common.Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class MINet(nn.Module): + def __init__(self, n_resgroups,n_resblocks, n_feats): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1,x2#x1=pd x2=pdfs + \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/restormer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3699cc809e438be4e1bd5efaba7d96e18d7f1602 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/restormer.py @@ -0,0 +1,307 @@ +## Restormer: 
Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + # print("feature before FFN projection:", x.max(), x.min()) + x = self.project_out(x) + # print("FFN output:", x.max(), x.min()) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, 
x): + b,c,h,w = x.shape + # print("Attention block input:", x.max(), x.min()) + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + # print("attention values:", attn) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("Attention block output:", out.max(), out.min()) + + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + # dim = 32, + # num_blocks = [4,4,4,4], + dim = 32, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + + # stack = [] + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + # print("encoder block1 feature:", out_enc_level1.shape, out_enc_level1.max(), out_enc_level1.min()) + # stack.append(out_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + # print("encoder block2 feature:", out_enc_level2.shape, out_enc_level2.max(), out_enc_level2.min()) + # stack.append(out_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + # print("encoder block3 feature:", out_enc_level3.shape, out_enc_level3.max(), out_enc_level3.min()) + # 
stack.append(out_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + # print("encoder block4 feature:", latent.shape, latent.max(), latent.min()) + # stack.append(latent) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + + return out_dec_level1 #, stack + + + +def build_model(args): + return Restormer() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/restormer_block.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/restormer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..170ae6a826e5a3e356cb48f8c89500c2d5242816 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/restormer_block.py @@ -0,0 +1,321 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class 
LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + # print("input to the LayerNorm layer:", x.max(), x.min()) + h, w = x.shape[-2:] + x = to_4d(self.body(to_3d(x)), h, w) + # print("output of the LayerNorm:", x.max(), x.min()) + return x + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature = 1.0 / ((dim//num_heads) ** 0.5) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + # print("input to the Attention layer:", x.max(), x.min()) + + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + # print("q range:", q.max(), q.min()) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + # print("scaling parameter:", self.temperature) + # print("attention matrix before softmax:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + attn = attn.softmax(dim=-1) + # print("attention matrix range:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + # print("v range:", v.max(), v.min(), v.mean()) + + out = (attn @ v) + # print("self-attention output:", out.max(), out.min()) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("output of attention layer:", out.max(), out.min()) + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, pre_norm): + super(TransformerBlock, self).__init__() + + self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + self.pre_norm = pre_norm + + def forward(self, 
x): + x = self.conv(x) + # print("restormer block input:", x.max(), x.min()) + if self.pre_norm: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + else: + x = self.norm1(x + self.attn(x)) + x = self.norm2(x + self.ffn(x)) + # x = self.norm1(self.attn(x)) + # x = self.norm2(self.ffn(x)) + # x = x + self.attn(x) + # x = x + self.ffn(x) + # print("restormer block output:", x.max(), x.min()) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim = 48, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + 
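# The upsampled level-3 features (2*dim channels) were just concatenated with the
+ # level-2 encoder skip (2*dim channels); the 1x1 reduce_chan_level2 conv brings the
+ # resulting 4*dim channels back to 2*dim before the level-2 decoder blocks run. + 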
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + + + return out_dec_level1 + + + + +if __name__ == '__main__': + model = Restormer() + # print(model) + + A = torch.randn((1, 1, 240, 240)) + x = model(A) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/swinIR.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/swinIR.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfc9c175c02531ac80a05ba4abfc0776f752dd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/swinIR.py @@ -0,0 +1,874 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * 
(self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
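+ # bottleneck: a 3x3 conv down to embed_dim // 4, a 1x1 conv, then a 3x3 conv back to embed_dim, which is cheaper than a single embed_dim -> embed_dim 3x3 conv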
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + # embed_dim must be divisible by each entry of num_heads + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 4, 4, 6], + embed_dim=32, num_heads=[4, 4, 4, 4], mlp_ratio=4, upsampler='pixelshuffledirect') + + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/swin_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..877bdfbffc91012e4ac31f41b838ebb116f4a825 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/swin_transformer.py @@ -0,0 +1,878 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang.
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
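+
+ Note: forward() un-embeds the BasicLayer output to a feature map, applies the resi_connection convolution, re-embeds it, and adds the block input as a residual.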
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=1, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + # print("shallow feature:", x.shape) + x = self.conv_after_body(self.forward_features(x)) + x + # print("deep feature:", x.shape) + x = self.upsample(x) + # print("after upsample:", x.shape) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x 
= self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + # return SwinIR(upscale=1, img_size=(240, 240), + # window_size=8, img_range=1., depths=[6, 6, 6, 6], + # embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + +if __name__ == '__main__': + height = 240 + width = 240 + window_size = 8 + model = SwinIR(upscale=1, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + x = torch.randn((1, 1, height, width)) + print("input:", x.shape) + x = model(x) + print("output:", x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet.zip b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet.zip new file mode 100644 index 0000000000000000000000000000000000000000..e7b71c1287551fd6119efd7c62da2196307998f5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:369de2a28036532c446409ee1f7b953d5763641ffd452675613e1bb9bba89eae +size 19354 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet/trans_unet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet/trans_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa61dbadfcaee9b26594307adc90cdd1d2a84e3e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet/trans_unet.py @@ -0,0 +1,489 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math + +from os.path import join as pjoin + +import torch +import torch.nn as nn +import numpy as np + +from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair +from scipy import ndimage +from . 
import vit_seg_configs as configs +# import .vit_seg_configs as configs +from .vit_seg_modeling_resnet_skip import ResNetV2 + + +logger = logging.getLogger(__name__) + + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class 
Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + # print("image size:", img_size, "grid size:", grid_size) + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + # print("patch_size_real", patch_size_real) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + # print("input x:", x.shape) ### (4, 3, 240, 240) + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + # print("output of the hybrid model:", x.shape) ### (4, 1024, 15, 15) + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = 
np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + # print(self.encoder) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) ### "features" store the features from different layers. + # print("embedding_output:", embedding_output.shape) ## (B, n_patch, hidden) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + # print("encoded feature:", encoded.shape) + return encoded, attn_weights, features + + # return embedding_output, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class ReconstructionHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + 
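+ # conv_more projects the reshaped (B, hidden, h, w) transformer tokens to head_channels before the cascade of upsampling DecoderBlocks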
self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + # print(self.blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): + super(VisionTransformer, self).__init__() + self.num_classes = num_classes + self.zero_head = zero_head + self.classifier = config.classifier + self.transformer = Transformer(config, img_size, vis) + self.decoder = DecoderCup(config) + self.segmentation_head = ReconstructionHead( + in_channels=config['decoder_channels'][-1], + out_channels=config['n_classes'], + kernel_size=3, + ) + self.activation = nn.Tanh() + self.config = config + + def forward(self, x): + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + ### hybrid CNN and transformer to extract features from the input images. + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + + # ### extract features with CNN only. 
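+        # (data-flow note, inferred from the code above: self.transformer returns
+        #  encoded tokens of shape (B, n_patch, hidden), per-layer attention weights,
+        #  and the ResNet skip features used by the decoder; the commented-out call
+        #  below is the CNN-only variant that returned only (x, features).)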
+ # x, features = self.transformer(x) # (B, n_patch, hidden) + + x = self.decoder(x, features) + # logits = self.segmentation_head(x) + logits = self.activation(self.segmentation_head(x)) + # print("logits:", logits.shape) + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + + + +def build_model(args): + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + print(net) + # net.load_from(weights=np.load(config_vit.pretrained_path)) + return net + + +if __name__ == "__main__": + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + net.load_from(weights=np.load(config_vit.pretrained_path)) + + input_data = torch.rand(4, 1, 240, 240).cuda() + print("input:", input_data.shape) + output = net(input_data) + print("output:", output.shape) diff --git 
a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet/vit_seg_configs.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet/vit_seg_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bed7d3346e30328a38ac6fef1621f788d905df1e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet/vit_seg_configs.py @@ -0,0 +1,132 @@ +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 4 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + # config.patches.grid = (16, 16) + config.patches.grid = (15, 15) ### for the input image with size (240, 240) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'recon' + config.pretrained_path = '/home/xiaohan/workspace/MSL_MRI/code/pretrained_model/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 1 + config.n_skip = 3 + config.activation = 'tanh' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. 
customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae80ef7d96c46cb91bda849901aef34008e62ed --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py @@ -0,0 +1,160 @@ +import math + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. 
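+            # The 1x1 projection (and its GroupNorm) is only created when the stride
+            # or the channel count changes; forward() checks hasattr(self, 'downsample'),
+            # so identity shortcuts stay parameter-free.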
+ self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - 
x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/transformer_modules.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/transformer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..900042884305619e39ad9b792b3b1936dfbd37b3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/transformer_modules.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, 
num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
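+        # Drawing uniformly from [2l-1, 2u-1] and applying erfinv_() below is the
+        # inverse-CDF trick: it produces (a truncated standard normal) / sqrt(2),
+        # which the later mul_(std * math.sqrt(2.)) and add_(mean) rescale to the
+        # requested mean/std before clamping to [a, b].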
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=4, + num_heads=4, + mlp_ratio=3) + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dd935cd34a121d8079a8f5184d4ea29e15c8da --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unet.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F + + +class Unet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. 
+ Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = image + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + # print("unet feature range:", output.max(), output.min()) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unet_restormer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f78aa5f192957b2706c6f691dfc66c427b65ae --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unet_restormer.py @@ -0,0 +1,234 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + # print("Unet feature range:", output.max(), output.min()) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unet_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..982fe23cec321de6b636e04c51b62037054ea432 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unet_transformer.py @@ -0,0 +1,346 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +对于Unet中的high-level feature, 用Transformer进行特征交互,然后和Transformer之前的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, 
(attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Unet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
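+        Note: the encoder output (a 15x15 map for 240x240 inputs, see fmp_size) is
+        flattened into patch tokens, refined by a small Transformer with a cls token
+        and learned position embeddings, reshaped back into a feature map,
+        concatenated with the pre-Transformer features, and passed to the decoder.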
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output = image + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack, output = self.encoder(output) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output) # [4, 225, 256], 225 is the number of patches. + patch_embed_input = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1])), int(np.sqrt(feature_output.shape[1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[-1], h, w) # [4,512*2,24,24] + + output = torch.cat((output, feature_output), 1) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output = self.decoder(output, stack) + + return output + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. 
+ + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Transformer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unimodal_transformer.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unimodal_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c99214e10bee2f7a1be49f8560799ffbc14ed0a5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/compare_models/unimodal_transformer.py @@ -0,0 +1,112 @@ +import torch + +from torch import nn +from .transformer_modules import build_transformer +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(): + return ReconstructionHead(input_dim=1, hidden_dim=12) + + + +class CMMT(nn.Module): + + def __init__(self, args): + super(CMMT, self).__init__() + + self.head = build_head() + + HEAD_HIDDEN_DIM = 12 + PATCH_SIZE = 16 + INPUT_SIZE = 240 + OUTPUT_DIM = 1 + + patch_dim = HEAD_HIDDEN_DIM* PATCH_SIZE ** 2 + num_patches = (INPUT_SIZE // PATCH_SIZE) ** 2 + + self.transformer = build_transformer(patch_dim) + + self.patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=PATCH_SIZE, p2=PATCH_SIZE), + ) + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, patch_dim)) + + + self.p1 = PATCH_SIZE + self.p2 = PATCH_SIZE + + self.tail = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x): + + b,_ , h, w = x.shape + + x = self.head(x) + + x= self.patch_embbeding(x) + + x += self.pos_embedding + + x = self.transformer(x) # b HW p1p2c + + c = int(x.shape[2]/(self.p1*self.p2)) + H = int(h/self.p1) 
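+        # Undo the patch embedding: each token of length p1*p2*c is folded back into
+        # a (p1, p2) patch and tiled over the (H, W) patch grid, recovering a
+        # (b, c, h, w) feature map before the final 1x1 conv.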
+ W = int(w/self.p2) + + x = x.reshape(b, H, W, self.p1, self.p2, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w,) + x = self.tail(x) + + return x + + +def build_model(args): + return CMMT(args) + + + + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/modules.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/modules.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
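+# The version below normalizes each (sample, channel) slice by its own spatial
+# mean and variance, with affine scale/shift parameters initialized to
+# N(0, 0.02) and zero in reset_parameters when affine is enabled.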
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/mynet.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/mynet.py new file mode 100644 index 0000000000000000000000000000000000000000..042c6b2ca805b7b3772f8f244edeffa812841296 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/networks_time/mynet.py @@ -0,0 +1,467 @@ +import torch, math +from torch import nn +from networks_time import common_freq as common + + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + + +class Block_Sequential(nn.Module): + def __init__(self, block1, block2): + super(Block_Sequential, self).__init__() + self.block1 = block1 + self.block2 = block2 + + def forward(self, x, t=None): + x = self.block1(x) + x = self.block2(x, t) + return x + + +class DiffTwoBranch(nn.Module): + def __init__(self, args): + super(DiffTwoBranch, self).__init__() + + num_group = 4 + num_every_group = args.base_num_every_group + self.args = args + + self.ch = args.num_channels + self.temb_ch = args.num_channels * 4 + args.temb_channels = self.temb_ch + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(args.num_channels, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ]) + + self.init_T2_frq_branch(args) + self.init_T2_spa_branch(args, num_every_group) + self.init_T2_fre_spa_fusion(args) + + self.init_T1_frq_branch(args) + self.init_T1_spa_branch(args, num_every_group) + + self.init_modality_fre_fusion(args) + self.init_modality_spa_fusion(args) + + + def init_T2_frq_branch(self, args): + ### T2frequency branch + self.head_fre = common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act) + + self.down1_fre = Block_Sequential(*[common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args)]) + + self.down1_fre_mo = common.FreBlock9(args.num_features, args) + + modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre = Block_Sequential(*modules_down2_fre) + + self.down2_fre_mo = common.FreBlock9(args.num_features, args) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre = Block_Sequential(*modules_down3_fre) + self.down3_fre_mo = common.FreBlock9(args.num_features, args) + + self.neck_fre = common.FreBlock9(args.num_features, args) + + self.neck_fre_mo = common.FreBlock9(args.num_features, args) + + modules_up1_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up1_fre = Block_Sequential(*modules_up1_fre) + self.up1_fre_mo = common.FreBlock9(args.num_features, args) + + modules_up2_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up2_fre = Block_Sequential(*modules_up2_fre) + self.up2_fre_mo = common.FreBlock9(args.num_features, args) + + modules_up3_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up3_fre = Block_Sequential(*modules_up3_fre) + self.up3_fre_mo = common.FreBlock9(args.num_features, args) + + # define tail module + self.tail_fre = common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act) + + def init_T2_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [] + self.head = common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, 
padding=1, act=args.act) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.down1 = Block_Sequential(*modules_down1) + + + self.down1_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.down2 = Block_Sequential(*modules_down2) + + self.down2_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.down3 = Block_Sequential(*modules_down3) + self.down3_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + self.neck = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + self.neck_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + modules_up1 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.up1 = Block_Sequential(*modules_up1) + + self.up1_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + modules_up2 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.up2 = Block_Sequential(*modules_up2) + self.up2_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + + modules_up3 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.up3 = Block_Sequential(*modules_up3) + self.up3_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + # define tail module + self.tail = common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act) + + def init_T2_fre_spa_fusion(self, args): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(args.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, args): + ### T2frequency branch + self.head_fre_T1 = common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre_T1 = Block_Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = 
common.FreBlock9(args.num_features, args) + + modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre_T1 = Block_Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = common.FreBlock9(args.num_features, args) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre_T1 = Block_Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = common.FreBlock9(args.num_features, args) + + + self.neck_fre_T1 = common.FreBlock9(args.num_features, args) + self.neck_fre_mo_T1 = common.FreBlock9(args.num_features, args) + + def init_T1_spa_branch(self, args, num_every_group): + ### spatial branch + self.head_T1 = common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = Block_Sequential(*modules_down1) + + + self.down1_mo_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = Block_Sequential(*modules_down2) + + self.down2_mo_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down3_T1 = Block_Sequential(*modules_down3) + self.down3_mo_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + self.neck_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + self.neck_mo_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + + def init_modality_fre_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + + def forward(self, main, aux, t=None): + + # self.temb_proj = torch.nn.Linear(temb_channels, + # out_channels) + # h = self.norm1(h) + # h = nonlinearity(h) + # h = self.conv1(h) + # + # h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + + temb = None + + + if not isinstance(t, type(None)): + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + 
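+        # A brief sketch of the conditioning used throughout this forward pass (standard DDPM
+        # sinusoidal embedding): get_timestep_embedding(t, num_channels) builds, for
+        # half = num_channels // 2 and freq_k = exp(-log(10000) * k / (half - 1)), k = 0..half-1,
+        # the vector [sin(t * freq_k), cos(t * freq_k)] of shape (B, num_channels); the two
+        # temb.dense linears with the swish nonlinearity() lift it to temb_ch = 4 * num_channels,
+        # and the T2 frequency/spatial stages below all receive this temb. Assuming 1x1x128x128
+        # inputs as in the __main__ smoke test, each DownSample halves the spatial size
+        # (128 -> 64 -> 32 -> 16) and the UpSampler stages mirror it back with additive skips.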
+ #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse, temb)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre, temb) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse, temb) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre, temb) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse, temb) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre, temb) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse, temb) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre, temb) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo, temb) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre, temb) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo, temb) # 64 + up2_fre_mo = self.up2_fre_mo(up2_fre, temb) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo, temb) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre, temb) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 temb + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse, temb) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse, temb) + down1_fuse_mo = self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse, temb) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse, temb) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse, temb) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse, temb) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse, temb) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse, temb) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse, temb) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse, temb) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo, temb) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse, temb) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = 
self.up3(up2_fuse_mo, temb) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse, temb) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + + # import matplotlib.pyplot as plt + # plt.axis('off') + # plt.imshow((255*up3_fre_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fre_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_fuse_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fuse_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + # breakpoint() + + res = self.tail(up3_fuse_mo) + + return {'img_out': res + main, 'img_fre': res_fre + main} + +def make_model(args): + return DiffTwoBranch(args) + + + +if __name__ == "__main__": + # Test the model + from utils.option import args + + network = DiffTwoBranch(args) + # network = build_model_from_name(args).cuda() + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + # network = nn.SyncBatchNorm.convert_sync_batchnorm(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + # Test the model + data = torch.randn(1, 1, 128, 128)#.cuda + out = network(data, data) + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/test_brats.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/test_brats.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f124e756b6826a35ab5e3ac0e5bb32d007865d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/test_brats.py @@ -0,0 +1,310 @@ +import os +import sys +from tqdm import tqdm +import argparse +import logging +from skimage import io + +from torchvision import transforms +from torch.utils.data import DataLoader +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import ToTensor +from networks.mynet import TwoBranch +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from utils.option import args + + +def normalise_mse(gt, pred): + """Compute Normalized Mean Squared Error (NMSE)""" + return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 + + + +parser = argparse.ArgumentParser() +parser.add_argument('--root_path', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/selected_images/') +parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') +parser.add_argument('--low_field_SNR', type=int, default=15, help='SNR of the simulated low-field image') +parser.add_argument('--phase', type=str, default='test', help='Name of phase') +parser.add_argument('--gpu', type=str, default='0', help='GPU to use') +parser.add_argument('--exp', type=str, default='msl_model', help='model_name') +parser.add_argument('--seed', type=int, default=1337, help='random seed') +parser.add_argument('--base_lr', type=float, default=0.0002, help='maximum epoch numaber to train') + +parser.add_argument('--model_name', type=str, default='unet_single', help='model_name') +parser.add_argument('--relation_consistency', type=str, default='False', help='regularize the consistency of feature relation') +parser.add_argument('--norm', type=str, default='False', help='Norm Layer between UNet and Transformer') +parser.add_argument('--input_normalize', 
type=str, default='mean_std', help='choose from [min_max, mean_std, divide]') +parser.add_argument('--test_sample', default="Ksample", help="Ksample | ColdDiffusion | DDPM") + +# args = parser.parse_args() + +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + +from utils.utils import * +from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial +from networks_time.mynet import DiffTwoBranch + +DEBUG = args.DEBUG +use_time_model = args.use_time_model +use_kspace = args.use_kspace +use_t2_in = True + +num_timesteps = 5 +image_size = 240 + +if args.MRIDOWN == "4X": + accelerate_mask = np.load("./dataloaders/example_mask/brats_4X_mask.npy") + accelerate_mask = torch.from_numpy(accelerate_mask).unsqueeze(0).clone().float() + print("accelerate_mask shape =", accelerate_mask.shape) +else: + accelerate_mask = None + +k_file = f"./dataloaders/example_mask/brats_{args.ACCELERATIONS[0]}_kspace_mask.npy" +if os.path.exists(k_file): + kspace_masks = np.load(k_file) + kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +else: + # Output a list of k-space kernels + kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0], + accelerate_mask=accelerate_mask + ) + kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + + +def normalize_output(out_img): + out_img = (out_img - out_img.min())/(out_img.max() - out_img.min() + 1e-8) + return out_img + + +test_sample = args.test_sample # Ksample | ColdDiffusion | DDPM + +if __name__ == "__main__": + ## make logger file + if use_kspace: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_kspace/' + + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not isinstance(args.test_tag, type(None)): + snapshot_path = snapshot_path.rstrip("/") + f'_{args.test_tag}/' + + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + + + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_test = MyDataset(split='test', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([ToTensor()]), + base_dir=test_data_path, input_normalize = args.input_normalize) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=2, pin_memory=True) + + if args.phase == 'test': + + save_mode_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + print('load weights from ' + save_mode_path) + checkpoint = torch.load(save_mode_path) + network.load_state_dict(checkpoint['network']) + network.eval() + cnt = 0 + save_path = snapshot_path + '/result_case/' + feature_save_path = snapshot_path + '/feature_visualization/' + if not os.path.exists(save_path): + os.makedirs(save_path) + if not os.path.exists(feature_save_path): + os.makedirs(feature_save_path) + + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + 
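+        # NMSE below is normalise_mse() from the top of this file,
+        #     NMSE = ||gt - pred||_2^2 / ||gt||_2^2,
+        # evaluated on the /255-scaled images and reported x100 in the summary print.
+        # Worked example: normalise_mse(np.ones(4), 0.9 * np.ones(4)) = 0.04 / 4 = 0.01,
+        # which shows up as "NMSE: 1.0000" in the printout.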
t2_MSE_all, t2_PSNR_all, t2_SSIM_all, t2_NMSE_all = [], [], [], [] + + for (sampled_batch, sample_stats) in tqdm(testloader, ncols=70): + cnt += 1 + + print('processing ' + str(cnt) + ' image') + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + + t1_out, t2_out = None, None + + if use_kspace: + b = t2.shape[0] + t = torch.randint(num_timesteps - 1, num_timesteps, (b,), device=device).long() # t-1 + + mask = kspace_masks[t] + target_fft, _ = apply_tofre(t2.clone(), mask) + fft, mask = apply_tofre(t2_in.clone(), mask) + + fft = target_fft * mask + fft * (1 - mask) # Seems too easy + t2_in = apply_to_spatial(fft) + + + while t >= 0: + if use_time_model: + outputs = network(t2_in, t1_in, t)['img_out'] + else: + outputs = network(t2_in, t1_in)['img_out'] + + if t == 0: + mask = kspace_masks[0] # last one + t2_in = outputs + if use_time_model: + t2_out_2 = network(t2_in, t1_in, t)['img_out'] + else: + t2_out_2 = network(t2_in, t1_in)['img_out'] + else: + + if test_sample == "Ksample": # Ksample | ColdDiffusion | DDPM + + k_full = kspace_masks[-1] + t2_in_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # self.get_kspace_kernels(t - 1).cuda() # last one + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + # fre_amend = recon_sample_fre * k_residual + + t2_in_fre = t2_in_fre * (1 - k_residual) + recon_sample_fre * k_residual + + outputs = apply_to_spatial(t2_in_fre) + t2_in = outputs + + elif test_sample == "ColdDiffusion": + k_full = kspace_masks[-1] + # t2_in_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # self.get_kspace_kernels(t - 1).cuda() # last one + + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + + x_t_hat_fre = recon_sample_fre * kt + x_t_sub_1_hat_fre = recon_sample_fre * kt_sub_1 + + x_t_hat = apply_to_spatial(x_t_hat_fre) + x_t_sub_1_hat = apply_to_spatial(x_t_sub_1_hat_fre) + + outputs = t2_in - x_t_hat + x_t_sub_1_hat + + t2_in = outputs + + elif test_sample == "DDPM": + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + + recon_sample_fre, kt_sub_1 = apply_tofre(outputs, kt_sub_1) + fre_new = recon_sample_fre * kt_sub_1 + + outputs = apply_to_spatial(fre_new) + t2_in = outputs + + t = t - 1 + t2_out = outputs + + else: + t2_out = network(t2_in, t1_in)['img_out'] + t2_out_2 = network(t2_in, t1_in)['img_out'] + + t1_mean = sample_stats['t1_mean'].data.cpu().numpy()[0] + t1_std = sample_stats['t1_std'].data.cpu().numpy()[0] + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + + t1_img = 
(np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_2_img = (np.clip(t2_out_2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + + + io.imsave(save_path + str(cnt) + '_t1.png', bright(t1_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2.png', bright(t2_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_original.png', t2_img) + io.imsave(save_path + str(cnt) + '_t2_in.png', bright(t2_in_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_out_original.png', t2_out_img) + io.imsave(save_path + str(cnt) + '_t2_out.png', bright(t2_out_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_out2.png', bright(t2_out_2_img,0,0.8)) + + # ------------------------------------ + # NMSE: 5.4534 ± 1.5515 + # PSNR: 39.2132 ± 1.6888 + # SSIM: 0.9792 ± 0.0054 + # ------------------------------------ + # Save Path: model/FSMNet_BraTS_8x_kspace//result_case/ + + if t2_out is not None: + t2_out_img[t2_out_img < 0.0] = 0.0 + t2_img[t2_img < 0.0] = 0.0 + MSE = mean_squared_error(t2_img, t2_out_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_out_img) + SSIM = structural_similarity(t2_img, t2_out_img) + nmse = normalise_mse(t2_img/255, t2_out_img/255) + + t2_MSE_all.append(MSE) + t2_PSNR_all.append(PSNR) + t2_SSIM_all.append(SSIM) + t2_NMSE_all.append(nmse) + + print("[t2 MRI] MSE:", MSE, "PSNR:", PSNR, "SSIM:", SSIM, "NMSE:", nmse) + + + print("===> Evaluate Metric <===") + print("Results") + print("-" * 36) + print(f"{test_sample} NMSE: {np.array(t2_NMSE_all).mean() * 100:.4f} ± {np.array(t2_NMSE_all).std() * 100 :.4f}") + # print(f"MSE: {np.array(t2_MSE_all).mean():.4f} ± {np.array(t2_MSE_all).std():.4f}") + print(f"{test_sample} PSNR: {np.array(t2_PSNR_all).mean():.4f} ± {np.array(t2_PSNR_all).std():.4f}") + print(f"{test_sample} SSIM: {np.array(t2_SSIM_all).mean():.4f} ± {np.array(t2_SSIM_all).std():.4f}") + print("-" * 36) + print(f"Save Path: {save_path}") + + + + # print("[T2 MRI:] average MSE:", np.array(t2_MSE_all).mean(), "average PSNR:", np.array(t2_PSNR_all).mean(), "average SSIM:", np.array(t2_SSIM_all).mean()) + # print("[T2 MRI:] average MSE:", np.array(t2_MSE_all).std(), "average PSNR:", np.array(t2_PSNR_all).std(), "average SSIM:", np.array(t2_SSIM_all).std()) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/test_fastmri.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/test_fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..897b59c7e417f1f16a669ef679daeaf0347a41d4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/test_fastmri.py @@ -0,0 +1,245 @@ +import os +import sys +import logging +from skimage import io +from skimage import img_as_ubyte + +from torch.utils.data import DataLoader +from networks.mynet import TwoBranch + +from utils.option import args +from tqdm import tqdm +from utils.metric import nmse, psnr, ssim +from collections import defaultdict +from networks_time.mynet import DiffTwoBranch + +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + +def normalize_output(out_img): + out_img 
= (out_img - out_img.min())/(out_img.max() - out_img.min() + 1e-8)
+    return out_img
+
+from frequency_diffusion.degradation.k_degradation import apply_tofre, apply_to_spatial
+from utils.utils import *
+
+DEBUG = False
+use_kspace = args.use_kspace
+use_time_model = args.use_time_model
+num_timesteps = args.num_timesteps
+image_size = args.image_size
+snapshot_path=args.snapshot_path
+
+from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial
+
+
+kspace_masks = np.load(f"./dataloaders/example_mask/kspace_{args.ACCELERATIONS[0]}_mask.npy")
+kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda()
+
+# NOTE: the generated schedule below overwrites the masks loaded from the .npy file above.
+kspace_masks = get_ksu_kernel(num_timesteps, image_size,
+                              ksu_routine="LogSamplingRate",
+                              accelerated_factor=args.ACCELERATIONS[0]
+                              )
+
+kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda()
+
+
+
+print("kspace_masks shape: ", kspace_masks.shape)
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, save_path):
+    os.makedirs(save_path, exist_ok=True)
+
+    model.eval()
+    nmse_meter, psnr_meter, ssim_meter = [], [], []
+    direct_nmse, direct_psnr, direct_ssim = [], [], []
+    output_dic = defaultdict(dict)
+    target_dic = defaultdict(dict)
+    input_dic = defaultdict(dict)
+    direct_recon_dic = defaultdict(dict)
+
+    flag=0
+    last_name='no'
+
+    print("len of data_loader: ", len(data_loader))
+
+    for data in tqdm(data_loader):
+        pd, pdfs, _ = data
+        name = os.path.basename(pdfs[4][0]).split('.')[0]
+
+        target = pdfs[1].to(device)
+
+        mean, std = pdfs[2], pdfs[3]
+
+        fname = pdfs[4]
+        slice_num = pdfs[5]
+
+        mean = mean.unsqueeze(1).unsqueeze(2).to(device)
+        std = std.unsqueeze(1).unsqueeze(2).to(device)
+
+        pd_img = pd[1].unsqueeze(1).to(device)
+        pdfs_img = pdfs[0].unsqueeze(1).to(device)
+
+        pdfs_img_origin = pdfs_img.clone()
+
+        # Degradation
+        if use_kspace:
+            b = pdfs_img.shape[0]
+            t = torch.randint(num_timesteps - 1, num_timesteps, (b,), device=device).long()  # t-1
+            mask = kspace_masks[t]
+            fft, mask = apply_tofre(target.clone(), mask)
+            fft = fft * mask + 0.0
+            pdfs_img = apply_to_spatial(fft)
+
+            while t >= 0:
+                if use_time_model:
+                    outputs = network(pdfs_img, pd_img, t)['img_out']
+                else:
+                    outputs = network(pdfs_img, pd_img)['img_out']
+                if t == num_timesteps - 1:
+                    direct_recon = outputs
+
+                if t == 0:
+                    pdfs_img = outputs
+
+                else:
+                    k_full = kspace_masks[-1]
+                    faded_recon_sample_fre, k_full = apply_tofre(pdfs_img, k_full)
+
+                    with torch.no_grad():
+
+                        kt_sub_1 = kspace_masks[t-1]  # get_kspace_kernels(t - 2).cuda()
+                        kt = kspace_masks[t]  # self.get_kspace_kernels(t - 1).cuda()  # last one
+                        k_residual = kt_sub_1 - kt
+
+                        recon_sample_fre, k_residual = apply_tofre(outputs, k_residual)
+                        fre_amend = recon_sample_fre * k_residual
+                        faded_recon_sample_fre = faded_recon_sample_fre + fre_amend  # * (1-k_residual)
+                        # faded_recon_sample_fre = faded_recon_sample_fre * kt_sub_1
+                        outputs = apply_to_spatial(faded_recon_sample_fre)
+                        pdfs_img = outputs
+
+                t = t-1
+
+        else:
+            outputs = network(pdfs_img, pd_img)['img_out']
+            direct_recon = outputs  # single-pass reconstruction when no k-space schedule is used
+
+        outputs = outputs.squeeze(1)
+        direct_recon = direct_recon.squeeze(1)
+
+        outputs_save = outputs[0].cpu().clone().numpy()/6.0
+        outputs_save = np.clip(outputs_save, a_min=-1, a_max=1)
+        target_save = target[0].cpu().clone().numpy()/6.0
+        in_save = pdfs_img_origin[0][0].cpu().clone().numpy()/6.0
+
+        # Not sure if it was correct to convert to ubyte
+        outputs_save = img_as_ubyte(outputs_save)
+        target_save = img_as_ubyte(target_save)
+        in_save = img_as_ubyte(in_save)
+
+        io.imsave(save_path + str(name) + '_' + 
str(slice_num[0].cpu().numpy()) + '.png', target_save) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_in.png', in_save) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_out.png', outputs_save) + + outputs = outputs * std + mean + target = target * std + mean + inputs = pdfs_img_origin.squeeze(1) * std + mean + direct_recon = direct_recon * std + mean + + output_dic[fname[0]][slice_num[0]] = outputs[0] + target_dic[fname[0]][slice_num[0]] = target[0] + input_dic[fname[0]][slice_num[0]] = inputs[0] + direct_recon_dic[fname[0]][slice_num[0]] = direct_recon[0] + + # print("target/outputs shape: ", target.shape, outputs.shape) + our_nmse = nmse(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_psnr = psnr(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_ssim = ssim(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + + # print('name:{}, slice:{}, nmse:{}, psnr:{}, ssim:{}'.format(name, slice_num[0], our_nmse, our_psnr, our_ssim)) + + + for name in output_dic.keys(): + f_output = torch.stack([v for _, v in output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.append(our_nmse) + psnr_meter.append(our_psnr) + ssim_meter.append(our_ssim) + + direct_nmse.append(nmse(f_target.cpu().numpy(), torch.stack([v for _, v in direct_recon_dic[name].items()]).cpu().numpy())) + direct_psnr.append(psnr(f_target.cpu().numpy(), torch.stack([v for _, v in direct_recon_dic[name].items()]).cpu().numpy())) + direct_ssim.append(ssim(f_target.cpu().numpy(), torch.stack([v for _, v in direct_recon_dic[name].items()]).cpu().numpy())) + + nmse_meter_score = np.array(nmse_meter) + psnr_meter_score = np.array(psnr_meter) + ssim_meter_score = np.array(ssim_meter) + + direct_nmse_score = np.array(direct_nmse) + direct_psnr_score = np.array(direct_psnr) + direct_ssim_score = np.array(direct_ssim) + + print("===> Evaluate Metric <===") + print("Direct Results") + print("-" * 36) + print(f"NMSE: {np.mean(direct_nmse_score) * 100:.4f} ± {np.std(direct_nmse_score) * 100:.4f}") + print(f"PSNR: {np.mean(direct_psnr_score):.4f} ± {np.std(direct_psnr_score):.4f}") + print(f"SSIM: {np.mean(direct_ssim_score):.4f} ± {np.std(direct_ssim_score):.4f}") + print("-" * 36) + + print("===> Evaluate Metric <===") + print("Results") + print("-" * 36) + print(f"NMSE: {np.mean(nmse_meter_score) * 100:.4f} ± {np.std(nmse_meter_score) * 100:.4f}") + print(f"PSNR: {np.mean(psnr_meter_score):.4f} ± {np.std(psnr_meter_score):.4f}") + print(f"SSIM: {np.mean(ssim_meter_score):.4f} ± {np.std(ssim_meter_score):.4f}") + print("-" * 36) + print(f"Save Path: {save_path}") + + model.train() + return {'NMSE': np.mean(nmse_meter_score), 'PSNR': np.mean(psnr_meter_score), 'SSIM': np.mean(ssim_meter_score)} + + +from dataloaders.fastmri import build_dataset +if __name__ == "__main__": + + + + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + db_test = build_dataset(args, mode='val', use_kspace=use_kspace) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'test': + + save_mode_path = 
os.path.join(snapshot_path, 'best_checkpoint.pth') + # save_mode_path = os.path.join(snapshot_path, 'iter_100000.pth') + print('load weights from ' + save_mode_path) + checkpoint = torch.load(save_mode_path) + + weights_dict = {} + for k, v in checkpoint['network'].items(): + new_k = k.replace('module.', '') if 'module' in k else k + weights_dict[new_k] = v + + network.load_state_dict(weights_dict) + network.eval() + + eval_result = evaluate(network, testloader, device, save_path = snapshot_path + '/result_case/') + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/test_m4raw.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/test_m4raw.py new file mode 100644 index 0000000000000000000000000000000000000000..194a58ba4ce74b58fbed385b43981f2831b41dfd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/test_m4raw.py @@ -0,0 +1,321 @@ +import os +import sys +import logging +from skimage import io +from skimage import img_as_ubyte + +from torch.utils.data import DataLoader +from networks.mynet import TwoBranch + +from utils.option import args +from tqdm import tqdm +from utils.metric import nmse, psnr, ssim +from collections import defaultdict +from networks_time.mynet import DiffTwoBranch + +test_data_path = args.root_path + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +use_new_dataloader = True + + +# Results + + +def normalize_output(out_img): + out_img = (out_img - out_img.min()) / (out_img.max() - out_img.min() + 1e-8) + return out_img + + +from frequency_diffusion.degradation.k_degradation import apply_tofre, apply_to_spatial, apply_ksu_kernel +from utils.utils import * + + +num_timesteps = args.num_timesteps +image_size = args.image_size +distortion_sigma = 10 / 255 +use_kspace = args.use_kspace +use_time_model = args.use_time_model +DEBUG = args.DEBUG +snapshot_path=args.snapshot_path + +kspace_masks = np.load(f"./dataloaders/example_mask/m4raw_{args.ACCELERATIONS[0]}_mask.npy") +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +test_sample = args.test_sample # Ksample | ColdDiffusion | DDPM +frequency_distortion = True + +@torch.no_grad() +def evaluate(model, data_loader, device, save_path): + os.makedirs(save_path, exist_ok=True) + + model.eval() + nmse_meter = [] + psnr_meter = [] + ssim_meter = [] + nmse_meter_all = [] + psnr_meter_all = [] + ssim_meter_all = [] + output_dic = {} # defaultdict(dict) + target_dic = {} # efaultdict(dict) + input_dic = {} # defaultdict(dict) + + flag = 0 + last_name = 'no' + + print("len of data_loader: ", len(data_loader)) + + for sampled_batch in tqdm(data_loader): + t1_img, t1_in = sampled_batch['t1'], sampled_batch['t1_in'] + t2_img, t2_in = sampled_batch['t2'], sampled_batch['t2_in'] + + t1_img = t1_img.to(device) + t1_in = t1_in.to(device) + t2_img = t2_img.to(device) + t2_in = t2_in.to(device) + + mean, std = sampled_batch['t2_mean'], sampled_batch['t2_std'] + + name = sampled_batch['fname'] + fname = [name] + slice_num = sampled_batch['slice'] + + mean = mean.unsqueeze(1).unsqueeze(2).to(device) + std = std.unsqueeze(1).unsqueeze(2).to(device) + + t2_in_origin = t2_in.clone() + + # Degradation + if use_kspace: + b = 1 + t = torch.randint(num_timesteps - 1, num_timesteps, (b,), device=device).long() # t-1 + mask = kspace_masks[t] + fft, mask = apply_tofre(t2_in.clone(), mask) # t2_img + fft = fft * mask + 0.0 + t2_in = apply_to_spatial(fft) + t2_in_origin = t2_in.clone() + + while t >= 0: + # outputs = model(t2_in, t1_img)['img_out'] + if use_time_model: + outputs = model(t2_in, t1_img, t)['img_out'] + else: + 
outputs = model(t2_in, t1_img)['img_out'] + + if t == 0: + mask = kspace_masks[0] # last one + t2_in = outputs + + else: + if test_sample == "Ksample": # Ksample | ColdDiffusion | DDPM + k_full = kspace_masks[-1] + faded_recon_sample_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] + kt = kspace_masks[t] + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + fre_amend = recon_sample_fre * k_residual + faded_recon_sample_fre = faded_recon_sample_fre + fre_amend + + outputs = apply_to_spatial(faded_recon_sample_fre) + t2_in = outputs + + + elif test_sample == "ColdDiffusion": + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] + kt = kspace_masks[t] + + x_t_hat = apply_ksu_kernel(outputs, kt) + x_t_sub_1_hat = apply_ksu_kernel(outputs, kt_sub_1) + + outputs = t2_in - x_t_hat + x_t_sub_1_hat + + t2_in = outputs + + elif test_sample == "DDPM": + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + outputs = apply_ksu_kernel(kt_sub_1, kt_sub_1) + t2_in = outputs + + + + + t = t - 1 + + else: + outputs = model(t2_in, t1_img)['img_out'] + + # print("outputs shape: ", outputs.shape, outputs.min(), outputs.max()) + # print("t2_img shape: ", t2_img.shape, t2_img.min(), t2_img.max()) + + target = t2_img.clone().squeeze(1) * std + mean + inputs = t2_in_origin.clone().squeeze(1) * std + mean + outputs_save = outputs.clone().squeeze(1) * std + mean + + outputs_save = outputs_save.cpu().numpy() + # outputs_save = np.clip(outputs_save, a_min=-1, a_max=1) + target_save = target.cpu().numpy() + in_save = inputs.cpu().numpy() + + _min, _max = target_save.min(), target_save.max() + target_save = (((target_save - _min) / (_max - _min)) * 255).astype(np.uint8) + in_save = (((in_save - _min) / (_max - _min)) * 255).astype(np.uint8) + outputs_save = (((outputs_save - _min) / (_max - _min)) * 255).astype(np.uint8) + + # Not sure if it was correct to convert to ubyte + outputs_save = img_as_ubyte(outputs_save) + target_save = img_as_ubyte(target_save) + in_save = img_as_ubyte(in_save) + + # print("outputs_save shape: ", outputs_save.shape, outputs_save.min(), outputs_save.max()) + # print("target_save shape: ", target_save.shape, target_save.min(), target_save.max()) + # print("in_save shape: ", in_save.shape, in_save.min(), in_save.max()) + + if len(outputs_save.shape) > 3: + outputs_save = outputs_save.squeeze(0) + target_save = target_save.squeeze(0) + in_save = in_save.squeeze(0) + + if len(outputs_save.shape) > 3: + outputs_save = outputs_save.squeeze(0) + target_save = target_save.squeeze(0) + in_save = in_save.squeeze(0) + + name = name[0].numpy() + name_int = int(name) + + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '.png', target_save) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_in.png', in_save) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_out.png', outputs_save) + + outputs = outputs.squeeze(1) * std + mean + target = t2_img.squeeze(1) * std + mean + inputs = t2_in_origin.squeeze(1) * std + mean + + if name_int not in output_dic.keys(): + output_dic[name_int] = [] + target_dic[name_int] = [] + input_dic[name_int] = [] + + output_dic[name_int].append(outputs[0]) + target_dic[name_int].append(target[0]) + input_dic[name_int].append(inputs[0]) + + # print("target/outputs shape: ", target.shape, outputs.shape) + our_nmse = nmse(target[0].cpu().numpy(), 
outputs[0].cpu().numpy()) + our_psnr = psnr(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_ssim = ssim(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + + print(' name:{}, slice:{}, nmse:{}, psnr:{}, ssim:{}'.format(name, slice_num[0], our_nmse, our_psnr, our_ssim)) + + nmse_meter_all.append(our_nmse) + psnr_meter_all.append(our_psnr) + ssim_meter_all.append(our_ssim) + # print("psnr_meter_all: ", np.mean(psnr_meter_all)) + + for name in output_dic.keys(): + print("name: ", name, len(output_dic[name])) + # f_output = torch.stack([v for _, v in output_dic[name].items()]) + # f_target = torch.stack([v for _, v in target_dic[name].items()]) + f_output = torch.stack(list(output_dic[name])) + f_target = torch.stack(list(target_dic[name])) + + print("f_output shape: ", f_output.shape) + + if len(f_output.shape) > 3: + f_output = f_output.squeeze(1) + f_target = f_target.squeeze(1) + + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.append(our_nmse) + psnr_meter.append(our_psnr) + ssim_meter.append(our_ssim) + + nmse_meter_score = np.array(nmse_meter) + psnr_meter_score = np.array(psnr_meter) + ssim_meter_score = np.array(ssim_meter) + + nmse_meter_all_score = np.array(nmse_meter_all) + psnr_meter_all_score = np.array(psnr_meter_all) + ssim_meter_all_score = np.array(ssim_meter_all) + + print("===> Evaluate Metric <===") + print("Results") + print("-" * 36) + print(f"{test_sample} NMSE: {np.mean(nmse_meter_score) * 100:.4f} ± {np.std(nmse_meter_score) * 100:.4f}") + print(f"{test_sample} PSNR: {np.mean(psnr_meter_score):.4f} ± {np.std(psnr_meter_score):.4f}") + print(f"{test_sample} SSIM: {np.mean(ssim_meter_score):.4f} ± {np.std(ssim_meter_score):.4f}") + print("-" * 36) + print(f"All NMSE: {np.mean(nmse_meter_all_score) * 100:.4f} ± {np.std(nmse_meter_all_score) * 100:.4f}") + print(f"All PSNR: {np.mean(psnr_meter_all_score):.4f} ± {np.std(psnr_meter_all_score):.4f}") + print(f"All SSIM: {np.mean(ssim_meter_all_score):.4f} ± {np.std(ssim_meter_all_score):.4f}") + print("-" * 36) + print(f"Save Path: {save_path}") + + model.train() + return {'NMSE': np.mean(nmse_meter_score), 'PSNR': np.mean(psnr_meter_score), 'SSIM': np.mean(ssim_meter_score)} + + +from dataloaders.m4raw_std_dataloader import M4Raw_TestSet as M4Raw_TestSet_new, M4Raw_TrainSet as M4Raw_TrainSet_new + +from dataloaders.m4raw_dataloader import M4Raw_TestSet, M4Raw_TrainSet + + + + +if __name__ == "__main__": + + + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + if use_new_dataloader: + db_test = M4Raw_TestSet_new(args, use_kspace=use_kspace) # + else: + db_test = M4Raw_TestSet(args.root_path, args.MRIDOWN, use_kspace=use_kspace) + + # db_test = build_dataset(args, mode='val', use_kspace=use_kspace) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'test': + + save_mode_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + # save_mode_path = os.path.join(snapshot_path, 'iter_100000.pth') + print('load weights from ' + save_mode_path) + + try: + checkpoint = torch.load(save_mode_path) + except: + print("Missing keys:", set(model_state_dict.keys()) - set(loaded_state_dict.keys())) + + + weights_dict = {} + 
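+        # Checkpoints saved while the network was wrapped in nn.DataParallel store parameters
+        # under a 'module.' prefix; the loop below strips that prefix so the same checkpoint
+        # also loads into an unwrapped model.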
for k, v in checkpoint['network'].items(): + new_k = k.replace('module.', '') if 'module' in k else k + weights_dict[new_k] = v + + network.load_state_dict(weights_dict) + network.eval() + + eval_result = evaluate(network, testloader, device, save_path=snapshot_path + '/result_case/') + diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/train_brats.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/train_brats.py new file mode 100644 index 0000000000000000000000000000000000000000..43fb4c73be16ec6dca3b66026087d9ad5c4d6244 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/train_brats.py @@ -0,0 +1,398 @@ +from tqdm import tqdm +from tensorboardX import SummaryWriter +import logging, time, os, sys +import torch.optim as optim +from torchvision import transforms +from torch.utils.data import DataLoader +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import RandomPadCrop, ToTensor +from networks.mynet import TwoBranch +from utils.option import args +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +from utils.utils import * +from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial +from networks_time.mynet import DiffTwoBranch + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr + +# --use_time_model True --use_kspace True --ACCELERATIONS 4 --MRIDOWN 4X --low_field_SNR 20 --input_normalize mean_std +DEBUG = args.DEBUG +use_time_model = args.use_time_model +use_kspace = args.use_kspace # PSNR: 30.138548551934974 average SSIM: 0.770964106980312 + # PSNR: 31.325274490046855 average SSIM: 0.8589609042898623 4X if not + + # PSNR: 29.846184815515585 average SSIM: 0.8797758188214125 -> 31.18, 0.77 + # PSNR: 28.494279128317515 average SSIM: 0.8179950512965841 8X if not + +# kspace_refine = True # Albu with 41.33, w/ 42.00 +# mask_vacant = False +frequency_distortion = True + + + +num_timesteps = args.num_timesteps #30 +image_size = args.image_size #240 +distortion_sigma = 10/255 + +if args.MRIDOWN == "4X": + accelerate_mask = np.load("./dataloaders/example_mask/brats_4X_mask.npy") + accelerate_mask = torch.from_numpy(accelerate_mask).unsqueeze(0).clone().float() +else: + accelerate_mask = None + +# Output a list of k-space kernels +kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0], + accelerate_mask=accelerate_mask + ) +np.save(f"./dataloaders/example_mask/brats_{args.ACCELERATIONS[0]}_kspace_mask.npy", kspace_masks) + +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + + + +if __name__ == "__main__": + ## make logger file + if use_kspace: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_kspace/' + + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + device = 
torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_train = MyDataset(split='train', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([RandomPadCrop(), ToTensor()]), + base_dir=train_data_path, input_normalize = args.input_normalize, use_kspace=use_kspace) + + db_test = MyDataset(split='test', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([ToTensor()]), + base_dir=test_data_path, input_normalize = args.input_normalize) + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + fixtrainloader = DataLoader(db_train, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'train': + network.train() + + params = list(network.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + if not use_kspace: + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + else: + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=40000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + max_epoch = max_iterations // len(trainloader) + 1 + if use_kspace: + max_epoch = max_epoch * num_timesteps + + best_status = {'T1_NMSE': 10000000, 'T1_PSNR': 0, 'T1_SSIM': 0, + 'T2_NMSE': 10000000, 'T2_PSNR': 0, 'T2_SSIM': 0} + + fft_weight = 0.01 + criterion = nn.L1Loss().to(device, non_blocking=True) + freloss = Frequency_Loss().to(device, non_blocking=True) + start_time = time.time() + mask = None + + for epoch_num in tqdm(range(max_epoch), ncols=70): + time1 = time.time() + debug_time = False + + for i_batch, (sampled_batch, sample_stats) in enumerate(trainloader): + time2 = time.time() + + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + + # Degradation + if use_kspace: + b = t1_in.shape[0] + t = torch.randint(0, num_timesteps, (b,), device=device).long() + mask = kspace_masks[t] + + target_fft, _ = apply_tofre(t2.clone(), mask) + fft, mask = apply_tofre(t2_in.clone(), mask) + + # if np.random.rand() > (1 / (1 + num_timesteps)): + fft = target_fft * mask + fft * (1 - mask) # Seems too easy + + # Frequency Noise + if frequency_distortion: + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + # Add noise to unmasked frequencies to maintain the stochasticity that diffusion models typically rely on. 
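+                        # The noise below is multiplicative and relative: the magnitude (and, at half
+                        # the strength, the phase) is scaled by Gaussian noise whose std is a random
+                        # |N(0,1)| multiple of distortion_sigma, applied only at the k-space locations
+                        # selected by `mask`, so the frequencies kept from the undersampled input are
+                        # left untouched.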
+ sigma = distortion_sigma * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude = noise * fft_magnitude * mask # + noise * (1 - mask) + fft_magnitude += noise_magnitude + + sigma = distortion_sigma / 2 * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_phase) * sigma + noise_pha = noise * fft_phase * mask # + noise * (1 - mask) + fft_phase += noise_pha + + fft = fft_magnitude * torch.exp(1j * fft_phase) + + t2_in = apply_to_spatial(fft) + + time3 = time.time() + + if use_time_model and use_kspace: + outputs = network(t2_in, t1, t) + else: + outputs = network(t2_in, t1) + + loss = criterion(outputs['img_out'], t2) + criterion(outputs['img_fre'], t2) + \ + fft_weight * freloss(outputs['img_fre'], t2, mask) + + + time4 = time.time() + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + ### clip the gradients to a small range. + torch.nn.utils.clip_grad_norm_(network.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + if debug_time: + print("Optimizer Step Time: ", time.time() - time2) + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + + if iter_num % 100 == 0: + logging.info('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % (iter_num, time.time()-start_time, scheduler1.get_lr()[0], loss.item())) + if DEBUG: + break + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': network.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + + ## ================ Evaluate ================ + logging.info(f'Epoch {epoch_num} Evaluation:') + # print() + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + t2_MSE_first_step, t2_PSNR_first_step, t2_SSIM_first_step = [], [], [] + + t1_MSE_krecon, t1_PSNR_krecon, t1_SSIM_krecon = [], [], [] + t2_MSE_krecon, t2_PSNR_krecon, t2_SSIM_krecon = [], [], [] + ids = 0 + for (sampled_batch, sample_stats) in testloader: + + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + t_merge = torch.cat([t1_in, t2_in], dim=1) + + + if use_kspace: + t = num_timesteps - 1 + mask = kspace_masks[t] + target_fft, _ = apply_tofre(t2.clone(), mask) + fft, mask = apply_tofre(t2_in.clone(), mask) + + fft = target_fft * mask + fft * (1 - mask) # Seems too easy + t2_in = apply_to_spatial(fft) + + while t >= 0: + if use_time_model: + outputs = network(t2_in, t1, t)['img_out'] + else: + outputs = network(t2_in, t1)['img_out'] + + if t == num_timesteps - 1: + first_step_recon = outputs + + if t == 0: + mask = kspace_masks[0] # last one + t2_in = outputs + + else: + k_full = kspace_masks[-1] # True + t2_in_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # current one + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + + # fft = target_fft * mask + fft * (1 - mask) + t2_in_fre = t2_in_fre * (1 - k_residual) + recon_sample_fre * k_residual # substitute + + outputs = apply_to_spatial(t2_in_fre) + t2_in = outputs + + t = t - 1 + t2_out = t2_in + else: + t2_out = network(t2_in, 
t1)['img_out'] + + t1_out = None + + if args.input_normalize == "mean_std": + t1_mean = sample_stats['t1_mean'].data.cpu().numpy()[0] + t1_std = sample_stats['t1_std'].data.cpu().numpy()[0] + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + + + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_first_step_recon_img = (np.clip(first_step_recon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + + else: + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_first_step_recon_img = (np.clip(first_step_recon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + + if t1_out is not None: + + MSE = mean_squared_error(t1_img, t1_out_img) + PSNR = peak_signal_noise_ratio(t1_img, t1_out_img) + SSIM = structural_similarity(t1_img, t1_out_img) + t1_MSE_all.append(MSE) + t1_PSNR_all.append(PSNR) + t1_SSIM_all.append(SSIM) + + MSE = mean_squared_error(t1_img, t1_krecon_img) + PSNR = peak_signal_noise_ratio(t1_img, t1_krecon_img) + SSIM = structural_similarity(t1_img, t1_krecon_img) + t1_MSE_krecon.append(MSE) + t1_PSNR_krecon.append(PSNR) + t1_SSIM_krecon.append(SSIM) + + + if t2_out is not None: + MSE = mean_squared_error(t2_img, t2_out_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_out_img) + SSIM = structural_similarity(t2_img, t2_out_img) + t2_MSE_all.append(MSE) + t2_PSNR_all.append(PSNR) + t2_SSIM_all.append(SSIM) + # print("[t2 MRI] MSE:", MSE, "PSNR:", PSNR, "SSIM:", SSIM) + + MSE = mean_squared_error(t2_img, t2_first_step_recon_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_first_step_recon_img) + SSIM = structural_similarity(t2_img, t2_first_step_recon_img) + t2_MSE_first_step.append(MSE) + t2_PSNR_first_step.append(PSNR) + t2_SSIM_first_step.append(SSIM) + + MSE = mean_squared_error(t2_img, t2_krecon_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_krecon_img) + SSIM = structural_similarity(t2_img, t2_krecon_img) + t2_MSE_krecon.append(MSE) + t2_PSNR_krecon.append(PSNR) + t2_SSIM_krecon.append(SSIM) + + ids += 1 + if ids > 100: + break + + if t1_out is not None: + t1_mse = np.array(t1_MSE_all).mean() + t1_psnr = np.array(t1_PSNR_all).mean() + t1_ssim = np.array(t1_SSIM_all).mean() + + t1_krecon_mse = np.array(t1_MSE_krecon).mean() + t1_krecon_psnr = np.array(t1_PSNR_krecon).mean() + t1_krecon_ssim = np.array(t1_SSIM_krecon).mean() + + t2_mse = np.array(t2_MSE_all).mean() + t2_psnr = np.array(t2_PSNR_all).mean() + 
t2_ssim = np.array(t2_SSIM_all).mean() + + t2_first_step_mse = np.array(t2_MSE_first_step).mean() + t2_first_step_psnr = np.array(t2_PSNR_first_step).mean() + t2_first_step_ssim = np.array(t2_SSIM_first_step).mean() + + t2_krecon_mse = np.array(t2_MSE_krecon).mean() + t2_krecon_psnr = np.array(t2_PSNR_krecon).mean() + t2_krecon_ssim = np.array(t2_SSIM_krecon).mean() + + + if t2_psnr > best_status['T2_PSNR']: + best_status = {'T2_NMSE': t2_mse, 'T2_PSNR': t2_psnr, 'T2_SSIM': t2_ssim} + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': network.state_dict()}, best_checkpoint_path) + print('New Best Network:') + + logging.info(f"[T2 First MRI:] average MSE: {t2_first_step_mse} average PSNR: {t2_first_step_psnr} average SSIM: {t2_first_step_ssim}") + logging.info(f"[T2 MRI:] average MSE: {t2_mse} average PSNR: {t2_psnr} average SSIM: {t2_ssim}") + print("Snapshot_path = ", snapshot_path) + + if iter_num > max_iterations: + break + + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': network.state_dict()}, + save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + writer.close() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/train_fastmri.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/train_fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..7123371e8cd51c1b7a121ee07afe81f43a68db78 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/train_fastmri.py @@ -0,0 +1,364 @@ +import os +import sys +from tqdm import tqdm +from tensorboardX import SummaryWriter + +import logging +import time +import torch.optim as optim +from torch.utils.data import DataLoader +from networks.mynet import TwoBranch +from networks_time.mynet import DiffTwoBranch + +from utils.option import args + +from dataloaders.fastmri import build_dataset +from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial +from utils.lpips import LPIPS +from utils.metric import nmse, psnr, ssim, AverageMeter +from collections import defaultdict + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr +from utils.utils import * + +# --num_timesteps 30 --image_size 320 --use_kspace True --use_time_model True --ACCELERATIONS 4 --gpu 0 --phase train +DEBUG = False +use_kspace = args.use_kspace +frequency_distortion = False +use_time_model = args.use_time_model +num_timesteps = args.num_timesteps +image_size = 320 +distortion_sigma = 10 / 255 + +if args.phase == 'test': + kspace_masks = np.load(f"./dataloaders/example_mask/kspace_{args.ACCELERATIONS[0]}_mask.npy") + kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +else: + # Output a list of k-space kernels + kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0], + ) # args.ACCELERATIONS = [4] or [8] + + np.save(f"./dataloaders/example_mask/kspace_{args.ACCELERATIONS[0]}_mask.npy", kspace_masks) +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +print("kspace kernels shape:", kspace_masks.shape) # (1, 1, 320, 320) + + +@torch.no_grad() +def evaluate(model, data_loader, device): + model.eval() + + nmse_meter, psnr_meter, ssim_meter 
= AverageMeter(), AverageMeter(), AverageMeter() + direct_nmse, direct_psnr, direct_ssim = AverageMeter(), AverageMeter(), AverageMeter() + output_dic = defaultdict(dict) + target_dic = defaultdict(dict) + input_dic = defaultdict(dict) + direct_dic = defaultdict(dict) + + for id, data in enumerate(data_loader): + pd, pdfs, _ = data + name = os.path.basename(pdfs[4][0]).split('.')[0] + + target = pdfs[1].to(device) + mean, std = pdfs[2], pdfs[3] + + fname = pdfs[4] + slice_num = pdfs[5] + + mean = mean.unsqueeze(1).unsqueeze(2).to(device) + std = std.unsqueeze(1).unsqueeze(2).to(device) + + pd_img = pd[1].unsqueeze(1).to(device) + pdfs_img = pdfs[0].unsqueeze(1).to(device) + + pdfs_img_origin = pdfs_img.clone() + + # Degradation + if use_kspace: + b = pd_img.size(0) + t = torch.randint(num_timesteps - 1, num_timesteps, (b,), device=device).long() # t-1 + mask = kspace_masks[t] + fft, mask = apply_tofre(target.clone(), mask) + fft = fft * mask + 0.0 + pdfs_img = apply_to_spatial(fft) + # print("pdfs_img shape:", pdfs_img.shape, pdfs_img.min(), pdfs_img.max()) + + while t >= 0: + if use_time_model: + outputs = model(pdfs_img, pd_img, t)['img_out'] + else: + outputs = model(pdfs_img, pd_img)['img_out'] + + if t == num_timesteps - 1: + direct_recon = outputs + + if t == 0: + mask = kspace_masks[0] # last one + pdfs_img = outputs + + else: + k_full = kspace_masks[-1] + faded_recon_sample_fre, k_full = apply_tofre(pdfs_img, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # self.get_kspace_kernels(t - 1).cuda() # last one + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + fre_amend = recon_sample_fre * k_residual + faded_recon_sample_fre = faded_recon_sample_fre + fre_amend # * (1-k_residual) + # faded_recon_sample_fre = faded_recon_sample_fre * kt_sub_1 + outputs = apply_to_spatial(faded_recon_sample_fre) + pdfs_img = outputs + + t = t - 1 + + else: + outputs = model(pdfs_img, pd_img)['img_out'] + + target = target * std + mean + inputs = pdfs_img_origin.squeeze(1) * std + mean + outputs = outputs.squeeze(1) * std + mean + direct_recon = direct_recon.squeeze(1) * std + mean + + # print("target/outputs shape: ", target.shape, outputs.shape) + # our_nmse = nmse(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + # our_psnr = psnr(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + # our_ssim = ssim(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + + # print('name:{}, slice:{}, nmse:{}, psnr:{}, ssim:{}'.format(name, slice_num[0], our_nmse, our_psnr, our_ssim)) + + for i, f in enumerate(fname): + output_dic[f][slice_num[i]] = outputs[i] + target_dic[f][slice_num[i]] = target[i] + input_dic[f][slice_num[i]] = inputs[i] + direct_dic[f][slice_num[i]] = direct_recon[i] + + if id > 100: + break + + for name in output_dic.keys(): + f_output = torch.stack([v for _, v in output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.update(our_nmse, 1) + psnr_meter.update(our_psnr, 1) + ssim_meter.update(our_ssim, 1) + + direct_nmse.update( + nmse(f_target.cpu().numpy(), torch.stack([v for _, v in direct_dic[name].items()]).cpu().numpy()), 1) + direct_psnr.update( + psnr(f_target.cpu().numpy(), torch.stack([v for _, v in 
direct_dic[name].items()]).cpu().numpy()), 1) + direct_ssim.update( + ssim(f_target.cpu().numpy(), torch.stack([v for _, v in direct_dic[name].items()]).cpu().numpy()), 1) + + print("==> Evaluate Metric") + print("Direct Results ----------") + print("NMSE: {:.4}".format(direct_nmse.avg)) + print("PSNR: {:.4}".format(direct_psnr.avg)) + print("SSIM: {:.4}".format(direct_ssim.avg)) + print("------------------") + + print("==> Evaluate Metric") + print("Results ----------") + print("NMSE: {:.4}".format(nmse_meter.avg)) + print("PSNR: {:.4}".format(psnr_meter.avg)) + print("SSIM: {:.4}".format(ssim_meter.avg)) + print("------------------") + model.train() + + return {'NMSE': nmse_meter.avg, 'PSNR': psnr_meter.avg, 'SSIM': ssim_meter.avg} + + +if __name__ == "__main__": + ## make logger file + if use_kspace: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_kspace/' + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not frequency_distortion: + snapshot_path = snapshot_path.rstrip("/") + '_no_distortion/' + + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + # network = build_model_from_name(args).cuda() + device = torch.device('cuda') + network.to(device) + lpips_loss = LPIPS().eval().to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + # network = nn.SyncBatchNorm.convert_sync_batchnorm(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_train = build_dataset(args, mode='train', use_kspace=use_kspace) + db_test = build_dataset(args, mode='val', use_kspace=use_kspace) + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + mask = None + + if args.phase == 'train': + network.train() + + params = list(network.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + max_epoch = max_iterations // len(trainloader) + 1 + + best_status = {'NMSE': 10000000, 'PSNR': 0, 'SSIM': 0} + fft_weight = 0.01 + criterion = nn.L1Loss().to(device, non_blocking=True) + freloss = Frequency_Loss().to(device, non_blocking=True) + + for epoch_num in tqdm(range(max_epoch), ncols=70): + time1 = time.time() + start_time = time.time() + for i_batch, sampled_batch in enumerate(trainloader): + time2 = time.time() + + pd, pdfs, _ = sampled_batch + target = pdfs[1] + + mean, std = pdfs[2], pdfs[3] + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + target = target.unsqueeze(1) + + b = pd_img.size(0) + + pd_img = pd_img.to(device) # [4, 1, 320, 320] + pdfs_img = pdfs_img.to(device) # [4, 1, 320, 320] + target = target.to(device) # [4, 1, 320, 320] + + time3 = time.time() + + # Degradation + if use_kspace: + t = torch.randint(0, num_timesteps, (b,), device=device).long() + mask = kspace_masks[t] + + fft, mask = 
apply_tofre(target.clone(), mask) + fft = fft * mask + + # Frequency Noise + if frequency_distortion: + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + # Add noise to unmasked frequencies to maintain the stochasticity that diffusion models typically rely on. + sigma = distortion_sigma * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude = noise * fft_magnitude * mask # + noise * (1 - mask) + fft_magnitude += noise_magnitude + + sigma = distortion_sigma / 2 * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_phase) * sigma + noise_pha = noise * fft_phase * mask # + noise * (1 - mask) + fft_phase += noise_pha + + fft = fft_magnitude * torch.exp(1j * fft_phase) + + pdfs_img = apply_to_spatial(fft) + + # breakpoint() + if use_time_model: + outputs = network(pdfs_img, pd_img, t) + else: + outputs = network(pdfs_img, pd_img) + + loss = criterion(outputs['img_out'], target) + \ + fft_weight * freloss(outputs['img_fre'], target, mask) + \ + criterion(outputs['img_fre'], target) + \ + 0.01 * lpips_loss(outputs['img_out'], target).mean() + + time4 = time.time() + + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + ### clip the gradients to a small range. + torch.nn.utils.clip_grad_norm_(network.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + + print_iter = 100 # if not DEBUG else 5 + if iter_num % print_iter == 0: + logging.info('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % ( + iter_num, time.time() - start_time, scheduler1.get_lr()[0], loss.item())) + if DEBUG: + break + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': network.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + ## ================ Evaluate ================ + logging.info(f'Epoch {epoch_num} Evaluation:') + # print() + eval_result = evaluate(network, testloader, device) + + if eval_result['PSNR'] > best_status['PSNR']: + best_status = {'NMSE': eval_result['NMSE'], 'PSNR': eval_result['PSNR'], 'SSIM': eval_result['SSIM']} + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': network.state_dict()}, best_checkpoint_path) + print('New Best Network saved:', best_checkpoint_path) + + logging.info( + f"average MSE: {eval_result['NMSE']} average PSNR: {eval_result['PSNR']} average SSIM: {eval_result['SSIM']}") + print("Snapshot Path: ", snapshot_path) + + if iter_num > max_iterations: + break + + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': network.state_dict()}, + save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + writer.close() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/train_m4raw.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/train_m4raw.py new file mode 100644 index 0000000000000000000000000000000000000000..f829de6c4c16f2b64d6b8fd85593e00abea9951f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/train_m4raw.py @@ -0,0 +1,450 @@ +import os +import sys +from tqdm import tqdm +from tensorboardX import SummaryWriter +import logging +import time +import torch.optim as optim +from torch.utils.data import DataLoader +from networks.mynet import TwoBranch +from 
networks_time.mynet import DiffTwoBranch + +from utils.option import args +import matplotlib.pyplot as plt + +use_new_dataloader = True + +if use_new_dataloader: + from dataloaders.m4raw_std_dataloader import M4Raw_TestSet, M4Raw_TrainSet, normalize, normalize_instance_dim +else: + from dataloaders.m4raw_dataloader import M4Raw_TestSet, M4Raw_TrainSet, normalize, normalize_instance_dim + +from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial +from utils.lpips import LPIPS +from utils.metric import nmse, psnr, ssim, AverageMeter +from collections import defaultdict +# import imsave +from skimage.io import imsave + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr +from utils.utils import * + +frequency_distortion = True + + + +num_timesteps = args.num_timesteps +image_size = args.image_size +distortion_sigma = 10/255 +DEBUG = args.DEBUG +use_kspace = args.use_kspace +use_time_model = args.use_time_model + + +# Baseline ------------------------------------ +# NMSE: 3.3329 ± 0.4915 +# PSNR: 34.0418 ± 0.6790 +# SSIM: 0.8699 ± 0.0143 +# ------------------------------------ +# Save Path: model/FSMNet_m4raw_4x//result_case/ + +# ------------------------------------ +# NMSE: 3.1966 ± 0.3939 +# PSNR: 34.2115 ± 0.5978 +# SSIM: 0.8910 ± 0.0118 +# ------------------------------------ +# Save Path: model/FSMNet_m4raw_4x_kspace//result_case/ + + +# num_timesteps = 5 +image_size = 240 + +# if args.phase == 'test': +# kspace_masks = np.load(f"./dataloaders/example_mask/m4raw_{args.ACCELERATIONS[0]}_mask.npy") +# kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() +# +# else: + +use_in_mean_std = False + +# Output a list of k-space kernels +kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0], + ) # args.ACCELERATIONS = [4] or [8] + +np.save(f"./dataloaders/example_mask/m4raw_{args.ACCELERATIONS[0]}_mask.npy", kspace_masks) + +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +print("kspace kernels shape:", kspace_masks.shape) # (1, 1, 320, 320) + + +@torch.no_grad() +def evaluate(model, data_loader, device): + model.eval() + print_i = 1 + + nmse_meter = AverageMeter() + psnr_meter = AverageMeter() + ssim_meter = AverageMeter() + output_dic = defaultdict(dict) + target_dic = defaultdict(dict) + input_dic = defaultdict(dict) + + for id, sampled_batch in enumerate(data_loader): + + if use_new_dataloader: + t1_img, t1_in = sampled_batch['t1'], sampled_batch['t1_in'] + t2_img, t2_in = sampled_batch['t2'], sampled_batch['t2_in'] + else: + t1_img, t1_in = sampled_batch['ref_image_full'], sampled_batch['ref_image_sub'] + t2_img, t2_in = sampled_batch['tag_image_full'], sampled_batch['tag_image_sub'] + + + t1_img = t1_img.to(device) + # t1_in = t1_in.to(device) + t2_img = t2_img.to(device) + t2_in = t2_in.to(device) + + mean, std = sampled_batch['t2_mean'], sampled_batch['t2_std'] + + fname = sampled_batch['fname'] + slice_num = sampled_batch['slice'] + + mean = mean.unsqueeze(1).to(device) + std = std.unsqueeze(1).to(device) + + t2_in_origin = t2_in.clone() + + # Degradation + if use_kspace: + b = 1 + t = torch.randint(num_timesteps - 1, num_timesteps, (b,), device=device).long() # t-1 + mask = kspace_masks[t] + fft, mask = 
apply_tofre(t2_in.clone(), mask) # t2_img + fft = fft * mask + 0.0 + t2_in = apply_to_spatial(fft) + + # Save for debug: + # put t2_in and t2_in_oringin side by side + + + t2_in_origin = t2_in.clone() + + # print("test fft/mask = ", fft.shape, fft.shape) + + + if use_in_mean_std: + t2_in = t2_in * std + mean + t2_img = t2_img * std + mean + # print("Test After restore:", t2_in.shape, t2_img.shape) + + t2_in, mean, std = normalize_instance_dim(t2_in, eps=1e-11) + t2_img = normalize(t2_img, mean=mean, stddev=std, eps=1e-11) + t2_in = t2_in.float() + t2_img = t2_img.float() + + mean = mean[0] + std = std[0] + + # print("Test Re-normalize:", t2_in.shape, t2_img.shape) + + # print("in put t2_in shape:", t2_in.shape, t2_in_origin.shape) + + while t >= 0: + if use_time_model: + outputs = model(t2_in, t1_img, t)['img_out'] + else: + outputs = model(t2_in, t1_img)['img_out'] + + if t == 0: + mask = kspace_masks[0] # last one + t2_in = outputs + + else: + k_full = kspace_masks[-1] + faded_recon_sample_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t-1] #get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] #self.get_kspace_kernels(t - 1).cuda() # last one + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + fre_amend = recon_sample_fre * k_residual + faded_recon_sample_fre = faded_recon_sample_fre + fre_amend # * (1-k_residual) + # faded_recon_sample_fre = faded_recon_sample_fre * kt_sub_1 + outputs = apply_to_spatial(faded_recon_sample_fre) + t2_in = outputs + + t = t-1 + + else: + outputs = model(t2_in, t1_img)['img_out'] + + if print_i: + t2_in_save = torch.cat([t2_in, t2_in_origin, t2_img], dim=3).cpu().numpy()[0, 0] + + t2_in_save = (t2_in_save - t2_in_save.min()) / (t2_in_save.max() - t2_in_save.min()) + # t2_in_save = np.stack([t2_in_save, t2_in_save, t2_in_save], axis=2) + # save to file + os.makedirs("./debug", exist_ok=True) + save_path = f"./debug/{use_kspace}_{fname[0]}_{slice_num[0]}.png" + plt.imsave(save_path, t2_in_save, cmap='gray') + print_i = 0 + print("print_i") + + + t2_img = t2_img.squeeze(1) * std + mean + inputs = t2_in_origin.squeeze(1) * std + mean + outputs = outputs.squeeze(1) * std + mean + + # print("output:", t2_img.shape, outputs.shape, inputs.shape) + + for i, f in enumerate(fname): + + output_dic[f][slice_num[i]] = outputs[i] + target_dic[f][slice_num[i]] = t2_img[i] + input_dic[f][slice_num[i]] = inputs[i] + + if id > 100: + break + + for name in output_dic.keys(): + f_output = torch.stack([v for _, v in output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.update(our_nmse, 1) + psnr_meter.update(our_psnr, 1) + ssim_meter.update(our_ssim, 1) + + print("==> Evaluate Metric") + print("Results ----------") + print("NMSE: {:.4}".format(nmse_meter.avg)) + print("PSNR: {:.4}".format(psnr_meter.avg)) + print("SSIM: {:.4}".format(ssim_meter.avg)) + print("------------------") + model.train() + + return {'NMSE': nmse_meter.avg, 'PSNR': psnr_meter.avg, 'SSIM':ssim_meter.avg} + + + +if __name__ == "__main__": + ## make logger file + if use_kspace: + if use_new_dataloader: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_new_kspace/' + else: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}/' + + + 
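The k-space helpers used throughout these scripts (get_ksu_kernel, apply_tofre, apply_to_spatial from frequency_diffusion.degradation.k_degradation) are not included in this patch. Based on the commented-out definition left in FSMNet/utils/utils.py further down, apply_tofre appears to be a centred 2-D FFT that returns the k-space together with the mask, and apply_to_spatial is assumed to be its inverse. The sketch below restates that assumption and shows one refinement step of the reverse loop in evaluate() above, with kspace_masks taken to be the 0/1 undersampling masks from get_ksu_kernel and the model prediction x_pred taken as given; it is a minimal illustration, not the project's actual implementation.

```python
import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift

def apply_tofre(x, mask):
    # Centred 2-D FFT of an image batch (mirrors the commented-out helper in utils/utils.py);
    # returns the k-space and the mask moved to the same device, as the calls above expect.
    kspace = fftshift(fft2(x, dim=(-2, -1)), dim=(-2, -1))
    return kspace, mask.to(kspace.device)

def apply_to_spatial(kspace):
    # Assumed inverse of apply_tofre; the real part is kept since the inputs are real images.
    return ifft2(ifftshift(kspace, dim=(-2, -1)), dim=(-2, -1)).real

def refinement_step(x_t, x_pred, kspace_masks, t):
    # One step of the reverse loop in evaluate() for t >= 1: keep the frequencies already
    # trusted at step t and splice in the prediction only on the newly revealed band.
    x_fre, _ = apply_tofre(x_t, kspace_masks[-1])
    k_residual = kspace_masks[t - 1] - kspace_masks[t]   # band revealed when moving t -> t-1
    pred_fre, _ = apply_tofre(x_pred, k_residual)
    return apply_to_spatial(x_fre + pred_fre * k_residual)
```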
if not isinstance(args.test_tag, type(None)): + snapshot_path = snapshot_path.rstrip("/") + f'_{args.test_tag}/' + + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not frequency_distortion: + snapshot_path = snapshot_path.rstrip("/") + 'no_distortion/' + + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + if use_time_model: + model = DiffTwoBranch(args).cuda() + else: + model = TwoBranch(args).cuda() + + # model = build_model_from_name(args).cuda() + device = torch.device('cuda') + model.to(device) + lpips_loss = LPIPS().eval().to(device) + + if len(args.gpu.split(',')) > 1: + model = nn.DataParallel(model) + # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + # db_train = M4Raw_TrainSet(args.root_path, args.MRIDOWN, use_kspace=use_kspace) # build_dataset(args, mode='train') + # db_test = M4Raw_TestSet(args.root_path, args.MRIDOWN, use_kspace=use_kspace) # build_dataset(args, mode='val') + + db_train = M4Raw_TrainSet(args, use_kspace=use_kspace, DEBUG=DEBUG) # build_dataset(args, mode='train') + db_test = M4Raw_TestSet(args, use_kspace=use_kspace, DEBUG=DEBUG) # + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'train': + model.train() + + params = list(model.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + max_epoch = max_iterations // len(trainloader) + 1 + + + best_status = {'NMSE': 10000000, 'PSNR': 0, 'SSIM': 0} + fft_weight=0.01 + criterion = nn.L1Loss().to(device, non_blocking=True) + freloss = Frequency_Loss().to(device, non_blocking=True) + t = 0 + + for epoch_num in tqdm(range(max_epoch), ncols=70): + time1 = time.time() + start_time = time.time() + + for i_batch, sampled_batch in enumerate(trainloader): + time2 = time.time() + + # T1 is the reference image, T2 is the target image + t1_img, t1_in = sampled_batch['t1'], sampled_batch['t1_in'] + t2_img, t2_in = sampled_batch['t2'], sampled_batch['t2_in'] + + t1_img = t1_img.to(device) + t1_in = t1_in.to(device) + t2_img = t2_img.to(device) + t2_in = t2_in.to(device) + + time3 = time.time() + + # Degradation + if use_kspace: + t2_origin = t2_in.clone() + b = t1_in.size(0) + t = torch.randint(0, num_timesteps, (b,), device=device).long() + mask = kspace_masks[t] + + fft, mask = apply_tofre(t2_in.clone(), mask) # TODO t2_img + fft = fft * mask + + # Frequency Noise + if frequency_distortion: + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + # Add noise to unmasked frequencies to maintain the stochasticity that diffusion models typically rely on. 
+ sigma = distortion_sigma * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude = noise * fft_magnitude * mask # + noise * (1 - mask) + fft_magnitude += noise_magnitude + + sigma = distortion_sigma / 2 * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_phase) * sigma + noise_pha = noise * fft_phase * mask # + noise * (1 - mask) + fft_phase += noise_pha + + fft = fft_magnitude * torch.exp(1j * fft_phase) + + t2_in = apply_to_spatial(fft) + + if use_in_mean_std: + mean, std = sampled_batch['t2_mean'], sampled_batch['t2_std'] + mean = mean.unsqueeze(1).unsqueeze(1).to(device) + std = std.unsqueeze(1).unsqueeze(1).to(device) + + t2_in = t2_in * std + mean + t2_img = t2_img * std + mean + # print("Train After restore:", t2_in.shape, t2_img.shape, mean.shape, std.shape) + + t2_in, mean, std = normalize_instance_dim(t2_in, eps=1e-11) + t2_img = normalize(t2_img, mean=mean, stddev=std, eps=1e-11).detach() + t2_in = t2_in.float() + t2_img = t2_img.float() + # print("Train After restore:", t2_in.shape, t2_img.shape) + + # breakpoint() + if use_time_model: + outputs = model(t2_in, t1_img, t) # ['img_out'] + else: + outputs = model(t2_in, t1_img) # ['img_out'] + + loss = criterion(outputs['img_out'], t2_img) + \ + fft_weight * freloss(outputs['img_fre'], t2_img) + \ + criterion(outputs['img_fre'], t2_img) + + # 0.01 * lpips_loss(outputs['img_out'], target).mean() + + time4 = time.time() + + + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + ### clip the gradients to a small range. + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + + print_iter = 100 #if not DEBUG else 5 + if iter_num % print_iter == 0: + logging.info('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % (iter_num, time.time() - start_time, scheduler1.get_lr()[0], loss.item())) + if DEBUG: + break + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': model.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + + ## ================ Evaluate ================ + logging.info(f'Epoch {epoch_num} Evaluation:') + # print() + eval_result = evaluate(model, testloader, device) + + if eval_result['PSNR'] > best_status['PSNR']: + best_status = {'NMSE': eval_result['NMSE'], 'PSNR': eval_result['PSNR'], 'SSIM': eval_result['SSIM']} + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': model.state_dict()}, best_checkpoint_path) + print('New Best Network saved:', best_checkpoint_path) + + logging.info(f"average MSE: {eval_result['NMSE']} average PSNR: {eval_result['PSNR']} average SSIM: {eval_result['SSIM']}") + print("snapshot_path=", snapshot_path) + + if iter_num > max_iterations: + break + + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': model.state_dict()}, + save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + writer.close() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__init__.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/__init__.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f041fc1c0de080e82a75f98c690449d183b02a9 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/lpips.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/lpips.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4608e0d4c1f9caf5ea751ad549419dc35c5022e9 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/lpips.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/metric.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/metric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee9b2a52af2da476cef6074672d40dcb995f5565 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/metric.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/option.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/option.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..314f0afbf2aa01f39fa2cc7a06a51d264c6786e5 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/option.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/utils.cpython-310.pyc b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa8acca2463c08ca043938d88b518cd173efb296 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/cache/vgg.pth b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/lpips.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a30b875fb4aa39ccd8419759d2f841d62bbad6 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/lpips.py @@ -0,0 +1,184 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + 
total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + input = input.float() + target = target.float() + + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def 
__init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/metric.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..53ddb27a96bab67975beef06ca6819e628208153 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/metric.py @@ -0,0 +1,51 @@ + +import numpy as np +from skimage.metrics import peak_signal_noise_ratio, structural_similarity + +def nmse(gt, pred): + """Compute Normalized Mean Squared Error (NMSE)""" + return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 + + +def psnr(gt, pred): + """Compute Peak Signal to Noise Ratio metric (PSNR)""" + return peak_signal_noise_ratio(gt, pred, data_range=gt.max()) + + +def ssim(gt, pred, maxval=None): + """Compute Structural Similarity Index Metric (SSIM)""" + maxval = gt.max() if maxval is None else maxval + + ssim = 0 + for slice_num in range(gt.shape[0]): + ssim = ssim + structural_similarity( + gt[slice_num], pred[slice_num], data_range=maxval + ) + + ssim = ssim / gt.shape[0] + + return ssim + + +class AverageMeter(object): + """Computes and stores the average and current value. 
+ + Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 + """ + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.score = [] + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + self.score.append(val) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/option.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/option.py new file mode 100644 index 0000000000000000000000000000000000000000..2b21cde0c6ac461423bed37efc8290898713ab9a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/option.py @@ -0,0 +1,72 @@ +import argparse + +parser = argparse.ArgumentParser(description='MRI recon') +parser.add_argument('--root_path', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/selected_images/') +parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') +parser.add_argument('--low_field_SNR', type=int, default=0, help='SNR of the simulated low-field image') +parser.add_argument('--phase', type=str, default='train', help='Name of phase') +parser.add_argument('--gpu', type=str, default='0', help='GPU to use') +parser.add_argument('--exp', type=str, default='msl_model', help='model_name') +parser.add_argument('--max_iterations', type=int, default=100000, help='maximum epoch number to train') +parser.add_argument('--batch_size', type=int, default=8, help='batch_size per gpu') +parser.add_argument('--base_lr', type=float, default=0.0002, help='maximum epoch numaber to train') +parser.add_argument('--seed', type=int, default=1337, help='random seed') +parser.add_argument('--resume', type=str, default=None, help='resume') +parser.add_argument('--relation_consistency', type=str, default='False', help='regularize the consistency of feature relation') +parser.add_argument('--clip_grad', type=str, default='True', help='clip gradient of the network parameters') + + +parser.add_argument('--norm', type=str, default='False', help='Norm Layer between UNet and Transformer') +parser.add_argument('--input_normalize', type=str, default='mean_std', help='choose from [min_max, mean_std, divide]') + +parser.add_argument("--dist_url", default="63654") + +parser.add_argument('--scale', type=int, default=8, + help='super resolution scale') +parser.add_argument('--base_num_every_group', type=int, default=2, + help='super resolution scale') +parser.add_argument('--snapshot_path', default="None", type=str) + + +parser.add_argument('--rgb_range', type=int, default=255, + help='maximum value of RGB') +parser.add_argument('--n_colors', type=int, default=3, + help='number of color channels to use') +parser.add_argument('--augment', action='store_true', + help='use data augmentation') +parser.add_argument('--fftloss', action='store_true', + help='use data augmentation') +parser.add_argument('--fftd', action='store_true', + help='use data augmentation') +parser.add_argument('--fftd_weight', type=float, default=0.1, + help='use data augmentation') +parser.add_argument('--fft_weight', type=float, default=0.01) + + +# Model specifications +parser.add_argument('--model', type=str, default='MYNET') +parser.add_argument('--act', type=str, default='PReLU') +parser.add_argument('--data_range', type=float, default=1) +parser.add_argument('--num_channels', type=int, default=1) +parser.add_argument('--num_features', type=int, default=64) + 
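As an aside on the metric helpers added in utils/metric.py above: nmse, psnr and ssim operate on whole volumes passed as NumPy arrays with a leading slice axis, psnr and ssim take the ground-truth maximum as the data range, and AverageMeter accumulates one score per volume, which is how evaluate() uses them after stacking all slices of a file. The snippet below is only a usage sketch on random data; the shapes are illustrative assumptions.

```python
import numpy as np
from utils.metric import nmse, psnr, ssim, AverageMeter  # helpers defined above in this patch

gt = np.random.rand(16, 320, 320).astype(np.float32)      # one volume: (slices, H, W)
pred = gt + 0.01 * np.random.randn(*gt.shape).astype(np.float32)

psnr_meter = AverageMeter()
psnr_meter.update(psnr(gt, pred), n=1)                     # data_range defaults to gt.max()
print(f"NMSE {nmse(gt, pred):.4f}  PSNR {psnr_meter.avg:.2f}  SSIM {ssim(gt, pred):.4f}")
```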
+parser.add_argument('--n_feats', type=int, default=64, + help='number of feature maps') +parser.add_argument('--res_scale', type=float, default=0.2, + help='residual scaling') + +parser.add_argument('--MASKTYPE', type=str, default='random') # "random" or "equispaced" +parser.add_argument('--CENTER_FRACTIONS', nargs='+', type=float) +parser.add_argument('--ACCELERATIONS', nargs='+', type=int) + +parser.add_argument('--num_timesteps', type=int, default=5) +parser.add_argument('--image_size', type=int, default=240) +parser.add_argument('--distortion_sigma', type=float, default=10/255) +parser.add_argument('--DEBUG', action='store_true') +parser.add_argument('--use_kspace', action='store_true') +parser.add_argument('--use_time_model', action='store_true') + +parser.add_argument('--test_tag', default=None) +parser.add_argument('--test_sample', default="Ksample", help="Ksample | ColdDiffusion | DDPM") + +args = parser.parse_args() diff --git a/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/utils.py b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a8c9361564ec52c1dab3fb970616c8edc0893c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/FSMNet/utils/utils.py @@ -0,0 +1,96 @@ +import torch +from torch import nn +import numpy as np + +def cc(img1, img2): + eps = torch.finfo(torch.float32).eps + """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" + N, C, _, _ = img1.shape + img1 = img1.reshape(N, C, -1) + img2 = img2.reshape(N, C, -1) + img1 = img1 - img1.mean(dim=-1, keepdim=True) + img2 = img2 - img2.mean(dim=-1, keepdim=True) + cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum( + img1 **2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1))) + cc = torch.clamp(cc, -1., 1.) 
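+ # clamp guards against floating-point values just outside [-1, 1] before the correlation is averaged over batch and channels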
+ return cc.mean() + + +def gradient_calllback(network): + for name, param in network.named_parameters(): + if param.grad is not None: + # print("Gradient of {}: {}".format(name, param.grad.abs().mean())) + + if param.grad.abs().mean() == 0: + print("Gradient of {} is 0".format(name)) + + else: + print("Gradient of {} is None".format(name)) + + +def bright(x, a,b): + # input datatype np.uint8 + x = np.array(x, dtype='float') + x = x/(b-a) - 255*a/(b-a) + x[x>255.0] = 255.0 + x[x<0.0] = 0.0 + x = x.astype(np.uint8) + return x + +def trunc(x): + # input datatype float + x[x>255.0] = 255.0 + x[x<0.0] = 0.0 + return x + + + + + +def gradient_calllback(network): + for name, param in network.named_parameters(): + if param.grad is not None: + if param.grad.abs().mean() == 0: + print("Gradient of {} is 0".format(name)) + + else: + print("Gradient of {} is None".format(name)) + +class Frequency_Loss(nn.Module): + def __init__(self): + super(Frequency_Loss, self).__init__() + self.cri = nn.L1Loss() + self.cri_sum = nn.L1Loss(reduction="sum") + + def forward(self, x, y, mask=None): + x = torch.fft.fftshift(torch.fft.fft2(x)) # rfft2 + y = torch.fft.fftshift(torch.fft.fft2(y)) + + + + # def apply_tofre(x_start, mask): + # # B, C, H, W = x_start.shape + # kspace = fftshift(fft2(x_start, norm=None, dim=(-2, -1)), dim=(-2, -1)) # Default: all dimensions + # mask = mask.to(kspace.device) + # return kspace, mask + + x_mag = torch.abs(x) + y_mag = torch.abs(y) + x_ang = torch.angle(x) + y_ang = torch.angle(y) + if isinstance(mask, type(None)): + return self.cri(x_mag,y_mag) + self.cri(x_ang, y_ang) + + k = (1 - mask.to(x.device)).detach() + # W = x.shape[-1] + # k = k[..., :W // 2 + 1] + k_total = torch.sum(k) + + x_mag = x_mag * k + y_mag = y_mag * k + x_ang = x_ang * k + y_ang = y_ang * k + + # Compute L1 loss between magnitudes + return self.cri_sum(x_mag, y_mag) / k_total + self.cri_sum(x_ang, y_ang) / k_total + diff --git a/MRI_recon/code/Frequency-Diffusion/README.md b/MRI_recon/code/Frequency-Diffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..915cfe1e79d4db2f1439101a0da25dcd1a0a461b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/README.md @@ -0,0 +1,19 @@ +# Frequency-Diffusion + + +Run Knee FastMRI dataset, the test code is also in the end of the file +```bash + +cd FSMNet +bash bash/fastmri.sh + +``` + +Run Brain m4raw dataset, the test code is also in the end of the file +```bash + +cd FSMNet +bash bash/m4raw.sh + +``` + diff --git a/MRI_recon/code/Frequency-Diffusion/__init__.py b/MRI_recon/code/Frequency-Diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/bash/adden/brain.sh b/MRI_recon/code/Frequency-Diffusion/bash/adden/brain.sh new file mode 100644 index 0000000000000000000000000000000000000000..c0e05b7a031a05ea07cebbb21270234cb9ffa5f7 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/bash/adden/brain.sh @@ -0,0 +1,82 @@ +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + +datapath=/home/hao/data/medical/Brain/ +# /gamedrive/Datasets/medical/Brain/ + +dataset=Brain +domain=BraTS-GLI-T1C # T1C +aux_modality=T1N # T1C, T1N, T2W, T2F +num_channels=1 + + +# T1: T1-weighted MRI; T1c: gadolinium-contrast-enhanced T1-weighted MRI; + + +diffusion_type=twobranch_kspace # Easy NaN + + +time_step=30 +image_size=240 #128 +sampling_routine=x0_step_down_fre # 
x0_step_down_fre # x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 1 # l2 | l1 | l2_l1, l1 is better + + +tag=new_norm #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=1 # specify the GPU ids +# fre_before_attn + l1 +train_bs=2 # 4 | 32 | 24 | 36 + + +save_folder=./results/${diffusion_type}_${sampling_routine} + + +datapath=/home/hao/data/medical/Brain/ +datapath=/gamedrive/Datasets/medical/Brain/brats/Processed/ + + +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --debug --mode $mode # --debug, --discrete + +# FSM Brain +/gamedrive/Datasets/medical/FrequencyDiffusion/image_100patients_4X + +BraTS20_Training_099_99_t1.png +BraTS20_Training_099_99_t2.png + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/code/Frequency-Diffusion/bash/adden/fsm_brain.sh b/MRI_recon/code/Frequency-Diffusion/bash/adden/fsm_brain.sh new file mode 100644 index 0000000000000000000000000000000000000000..6382b1e4b75426e366d3c62cd25721c247a0d854 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/bash/adden/fsm_brain.sh @@ -0,0 +1,69 @@ +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + + +dataset=fsm_brain +num_channels=1 +diffusion_type=twobranch_kspace + + +datapath=/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X + + +time_step=25 +image_size=240 #128 +sampling_routine=x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 | l1 | l2_l1, l1 is better + + +tag=adden_brain #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=1 # specify the GPU ids +# fre_before_attn + l1 +train_bs=4 # 4 | 32 | 24 | 36 +accelerate_factor=8 + +save_folder=./results/${diffusion_type}_${sampling_routine} + + +normalizer="mean_std" +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --accelerate_factor $accelerate_factor\ + --mode $mode --normalizer $normalizer --debug # --debug, --discrete + +# FSM Brain + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # 
deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/code/Frequency-Diffusion/bash/adden/knee.sh b/MRI_recon/code/Frequency-Diffusion/bash/adden/knee.sh new file mode 100644 index 0000000000000000000000000000000000000000..27f30c851971f46d1a241a63da7c9583e0ccd5a7 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/bash/adden/knee.sh @@ -0,0 +1,68 @@ +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + + +dataset=fsm_brain +num_channels=1 +diffusion_type=twobranch_kspace + + +datapath=/gamedrive/Datasets/medical/FrequencyDiffusion/singlecoil_train_selected + + +time_step=25 +image_size=320 #128 +sampling_routine=x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 | l1 | l2_l1, l1 is better + + +tag=adden_brain #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=1 # specify the GPU ids +train_bs=4 # 4 | 32 | 24 | 36 +accelerate_factor=8 + +save_folder=./results/${diffusion_type}_${sampling_routine} + + +normalizer="mean_std" +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --accelerate_factor $accelerate_factor\ + --mode $mode --normalizer $normalizer --debug # --debug, --discrete + +# FSM Brain + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/code/Frequency-Diffusion/bash/bask/brain.sh b/MRI_recon/code/Frequency-Diffusion/bash/bask/brain.sh new file mode 100644 index 0000000000000000000000000000000000000000..a085b08e428a3b1493ec5f890cc13a6d58a3a665 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/bash/bask/brain.sh @@ -0,0 +1,82 @@ +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + +datapath=/home/hao/data/medical/Brain/ +# /gamedrive/Datasets/medical/Brain/ + +dataset=Brain +domain=BraTS-GLI-T1C # T1C +aux_modality=T1N # T1C, T1N, T2W, T2F +num_channels=1 + + +# T1: T1-weighted MRI; T1c: gadolinium-contrast-enhanced T1-weighted MRI; + + 
+diffusion_type=twobranch_kspace # Easy NaN + + +time_step=30 +image_size=480 #128 +sampling_routine=x0_step_down_fre # x0_step_down_fre # x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 1 # l2 | l1 | l2_l1, l1 is better + + +tag=new_norm #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=1 # specify the GPU ids +# fre_before_attn + l1 +train_bs=2 # 4 | 32 | 24 | 36 + + +save_folder=./results/${diffusion_type}_${sampling_routine} + + +datapath=/home/hao/data/medical/Brain/ +datapath=/gamedrive/Datasets/medical/Brain/brats/Processed/ + + +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --debug --mode $mode # --debug, --discrete + +# FSM Brain +/gamedrive/Datasets/medical/FrequencyDiffusion/image_100patients_4X + +BraTS20_Training_099_99_t1.png +BraTS20_Training_099_99_t2.png + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/code/Frequency-Diffusion/bash/bask/fsm_brain.sh b/MRI_recon/code/Frequency-Diffusion/bash/bask/fsm_brain.sh new file mode 100644 index 0000000000000000000000000000000000000000..f285065b12014ca64037683ec4afd95db79c85a8 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/bash/bask/fsm_brain.sh @@ -0,0 +1,68 @@ +mamba activate diffmri +cd /bask/projects/j/jiaoj-rep-learn/Hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + + +dataset=fsm_brain +num_channels=1 +diffusion_type=twobranch_kspace # Easy NaN +datapath=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_4X/ + + +time_step=30 +image_size=240 #128 +sampling_routine=x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 | l1 | l2_l1, l1 is better + + +tag=fsm_brain #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=0 # specify the GPU ids +# fre_before_attn + l1 +train_bs=4 # 4 | 32 | 24 | 36 + + +save_folder=./results/${diffusion_type}_${sampling_routine} +normalizer="mean_std" + +mode=train +example_frequency_img="/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_4X/BraTS20_Training_036_90_t2_4X_undermri.png" # some example img +example_frequency_img="" + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels 
$num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --normalizer $normalizer \ + --example_frequency_img $example_frequency_img --debug --mode $mode # --debug, --discrete + +# FSM Brain + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/code/Frequency-Diffusion/bash/bask/knee.sh b/MRI_recon/code/Frequency-Diffusion/bash/bask/knee.sh new file mode 100644 index 0000000000000000000000000000000000000000..fd3c61236ab8da5543f76676e67c2a921e6f1d3d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/bash/bask/knee.sh @@ -0,0 +1,64 @@ +mamba activate diffmri +cd /bask/projects/j/jiaoj-rep-learn/Hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + + +dataset=fsm_knee +num_channels=1 +diffusion_type=twobranch_kspace # Easy NaN +datapath=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/singlecoil_train_selected + + +time_step=30 +image_size=320 #320 +sampling_routine=x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 | l1 | l2_l1, l1 is better + + +tag=fsm_knee #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=0 # specify the GPU ids +train_bs=4 # 4 | 32 | 24 | 36 + +normalizer="mean_std" +save_folder=./results/${diffusion_type}_${sampling_routine} + + +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --normalizer $normalizer --debug --mode $mode # --debug, --discrete + + + +mode=test +checkpoint=results/71_twobranch_kspace_x0_step_down_fre_new_loss/model.pt + + +#deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode + # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/BRATS_dataloader.py b/MRI_recon/code/Frequency-Diffusion/dataset/BRATS_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..bc433096d6058d9c5e7a259e56a6af2da385737c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/BRATS_dataloader.py @@ -0,0 +1,419 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from 
skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset +from torchvision import transforms + + +from dataset.m4_utils.transform_albu import get_albu_transforms + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', image_size=(128,128), MRIDOWN='4X', \ + SNR=15, transform=None, input_normalize=None, debug=False): + + super().__init__() + self._base_dir = base_dir + '/' + # self._MRIDOWN = MRIDOWN + + + self.transforms = get_albu_transforms(split, image_size) + self.kspace_refine = "False" + self._MRIDOWN = "4X" + + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.t1_krecon_images, self.t2_krecon_images = [], [] + self.splits_path = base_dir.replace("image_100patients_4X", "cv_splits_100patients") + + if split=='train': + self.train_file = self.splits_path + '/train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + '/test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + if debug: + self.t1_images = self.t1_images[:10] + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + + if SNR == 0: + t1_under_path = image_path + + if self.kspace_refine == "False": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + elif self.kspace_refine == "True": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_krecon') + + if self.kspace_refine == "False": + t1_krecon_path = image_path + t2_krecon_path = image_path + + # if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + + else: + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + t1_krecon_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_krecon_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + self.t1_krecon_images.append(t1_krecon_path) + self.t2_krecon_images.append(t2_krecon_path) + + self.transform 
= transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + def update_chunk(self): + pass + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index])) / 255.0 + + t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1_krecon = np.array(Image.open(self._base_dir + self.t1_krecon_images[index]))/255.0 + t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2_krecon = np.array(Image.open(self._base_dir + self.t2_krecon_images[index]))/255.0 + + t1 = np.asarray(t1, np.float32) + t2 = np.asarray(t2, np.float32) + + if self.input_normalize == "mean_std": + t1_in, t1_mean, t1_std = normalize_instance(t1_in, eps=1e-11) + t1 = normalize(t1, t1_mean, t1_std, eps=1e-11) + t2_in, t2_mean, t2_std = normalize_instance(t2_in, eps=1e-11) + t2 = normalize(t2, t2_mean, t2_std, eps=1e-11) + + t1_krecon = normalize(t1_krecon, t1_mean, t1_std, eps=1e-11) + t2_krecon = normalize(t2_krecon, t2_mean, t2_std, eps=1e-11) + + ### clamp input to ensure training stability. + t1_in = np.clip(t1_in, -6, 6) + t1 = np.clip(t1, -6, 6) + t2_in = np.clip(t2_in, -6, 6) + t2 = np.clip(t2, -6, 6) + + t1_krecon = np.clip(t1_krecon, -6, 6) + t2_krecon = np.clip(t2_krecon, -6, 6) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + t1_in = (t1_in - t1_in.min())/(t1_in.max() - t1_in.min()) + t1 = (t1 - t1.min())/(t1.max() - t1.min()) + t2_in = (t2_in - t2_in.min())/(t2_in.max() - t2_in.min()) + t2 = (t2 - t2.min())/(t2.max() - t2.min()) + sample_stats = 0 + + t1_mean = 0 + t1_std = 1 + t2_mean = 0 + t2_std = 1 + + elif self.input_normalize == "divide": + sample_stats = 0 + + + sample = {'image_in': t1_in, + 'image': t1, + 'image_krecon': t1_krecon, + 'target_in': t2_in, + 'target': t2, + 'target_krecon': t2_krecon} + + # print("images shape:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + + # return sample, sample_stats + + t1_main = False + # t1 support t2 accelerate + + if t1_main: + img = t1 + img_mean = t1_mean + img_std = t1_std + + aux = t2 + aux_mean = t2_mean + aux_std = t2_std + else: + img = t2 + img_mean = t2_mean + img_std = t2_std + + aux = t1 + aux_mean = t1_mean + aux_std = t1_std + + + # print("img shape:", img.shape, aux.shape, img.max()) # 240, 240 + + data_dict = self.transforms(image=img, image2=aux) + img = data_dict['image'] + aux = data_dict['image2'] + + img = np.asarray(np.expand_dims(img, axis=0), np.float32) + aux = np.asarray(np.expand_dims(aux, axis=0), np.float32) + + data = {"img": img, "aux": aux, + "img_mean": np.float32(img_mean), "img_std": np.float32(img_std), + "aux_mean": np.float32(aux_mean), "aux_std": np.float32(aux_std), + } + + return data + + + +def add_gaussian_noise(img, mean=0, std=1): + noise = std * torch.randn_like(img) + mean + noisy_img = img + noise + return torch.clamp(noisy_img, 0, 1) + + + +class AddNoise(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + add_gauss_noise = 
transforms.GaussianBlur(kernel_size=5) + add_poiss_noise = transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)) + + add_noise = transforms.RandomApply([add_gauss_noise, add_poiss_noise], p=0.5) + + img_in = add_noise(img_in) + target_in = add_noise(target_in) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + + return sample + + + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_krecon = sample['image_krecon'] + target_krecon = sample['target_krecon'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + img_krecon = np.pad(img_krecon, pad_size, mode='reflect') + target_krecon = np.pad(target_krecon, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + img_krecon = img_krecon[ww:ww+crop_size, hh:hh+crop_size] + target_krecon = target_krecon[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'image_krecon': img_krecon, \ + 'target_in': target_in, 'target': target, 'target_krecon': target_krecon} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + +class RandomFlip(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + # horizontal flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 1) + img = cv2.flip(img, 1) + target_in = cv2.flip(target_in, 1) + target = cv2.flip(target, 1) + + # vertical flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 0) + img = cv2.flip(img, 0) + target_in = cv2.flip(target_in, 0) + target = cv2.flip(target, 0) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + + +class RandomRotate(object): + def __call__(self, sample, center=None, scale=1.0): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + degrees = [0, 90, 180, 270] + angle = 
random.choice(degrees) + + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + + img_in = cv2.warpAffine(img_in, matrix, (w, h)) + img = cv2.warpAffine(img, matrix, (w, h)) + target_in = cv2.warpAffine(target_in, matrix, (w, h)) + target = cv2.warpAffine(target, matrix, (w, h)) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + + image_krecon = sample['image_krecon'][:, :, None].transpose((2, 0, 1)) + target_krecon = sample['target_krecon'][:, :, None].transpose((2, 0, 1)) + + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + image_krecon = torch.from_numpy(image_krecon).float() + target_krecon = torch.from_numpy(target_krecon).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'image_in': img_in, + 'image': img, + 'target_in': target_in, + 'target': target, + 'image_krecon': image_krecon, + 'target_krecon': target_krecon} diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/__init__.py b/MRI_recon/code/Frequency-Diffusion/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6736b58fe7f6c85492b3f8a78ad34bb1c49520 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/__init__.py @@ -0,0 +1,3 @@ +from .brain import BrainDataset +from .celeb import Dataset, Dataset_Aug1 + diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/basic.py b/MRI_recon/code/Frequency-Diffusion/dataset/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..da03388bb83aa0a49793ebc3cd7be4302f0d4fc5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/basic.py @@ -0,0 +1,460 @@ + +# Dataloader for abdominal images +import glob +import numpy as np +from .m4_utils import niftiio as nio +from .m4_utils import transform_utils as trans +from .m4_utils.abd_dataset_utils import get_normalize_op +from .m4_utils.transform_albu import get_albu_transforms, get_resize_transforms +import copy +import random, cv2, os +import torch.utils.data as torch_data +import math +import itertools +from pdb import set_trace +from multiprocessing import Process +import albumentations as A +from tqdm import tqdm + + +def get_basedir(data_dir): + return os.path.join(data_dir, "Abdominal") + + +class BasicDataset(torch_data.Dataset): + def __init__(self, fineSize, mode, transforms, base_dir, domains: list, aux_modality, pseudo = False, + idx_pct = [0.7, 0.1, 0.2], tile_z_dim = 3, extern_norm_fn = None, + LABEL_NAME=["bg", "fore"], debug=False, nclass=4, num_channels=3, + filter_non_labeled=False, use_diff_axis_view=False, chunksize=200): + """ + Args: + mode: 'train', 'val', 'test', 'test_all' + transforms: naive data augmentations used by default. Photometric transformations slightly better than those configured by Zhang et al. 
(bigaug) + idx_pct: train-val-test split for source domain + extern_norm_fn: feeding data normalization functions from external, only concerns CT-MR cross domain scenario + """ + super(BasicDataset, self).__init__() + + self.fineSize = fineSize + self.transforms = transforms + self.nclass = nclass + self.debug = debug + self.is_train = True if mode == 'train' else False + self.phase = mode + self.domains = domains + self.num_channels = num_channels + + # Modality + self.main_modality = domains[-1].split("-")[-1] + self.aux_modality = aux_modality.upper() + + print(f"=== Donmain: {domains}, Main modality: {self.main_modality}, Aux modality: {self.aux_modality}") + + self.pseudo = pseudo + self.all_label_names = LABEL_NAME + self.nclass = len(LABEL_NAME) + self.tile_z_dim = tile_z_dim + self._base_dir = base_dir + self.idx_pct = idx_pct + # self.albu_transform = get_albu_transforms((fineSize, fineSize)) + self.test_resizer = get_resize_transforms(fineSize) + self.fake_interpolate = True # True + self.use_diff_axis_view = use_diff_axis_view + self.filter_non_labeled = filter_non_labeled + self.input_window = 1 + + self.resizer = A.Compose([ + A.Resize(fineSize[0], fineSize[1], interpolation=cv2.INTER_NEAREST) + ], p=1.0, additional_targets={'image2': 'image', "mask2": "mask"}) + + self.img_pids = {} + for _domain in self.domains: # load file names + if "BraTS" in _domain: + self.img_pids[_domain] = sorted([ fid.split("-")[-2] for fid in + glob.glob(self._base_dir + "/" + _domain + "/img/*.nii.gz") ], + key = lambda x: int(x)) + + else: + self.img_pids[_domain] = sorted([fid.split("_")[-1].split(".nii.gz")[0] for fid in + glob.glob(self._base_dir + "/" + _domain + "/img/*.nii.gz")], + key=lambda x: int(x)) + + self.scan_ids = self.__get_scanids(mode, idx_pct) # train val test split in terms of patient ids + try: + print(f'For {self.phase} on {[_dm for _dm in self.domains]} using scan ids len = ' + \ + f'{[len(self.scan_ids[_dm]) for _dm in self.scan_ids.keys()]}') + except: + print("Errors of self.scan_ids") + print(self.scan_ids) + + + self.info_by_scan = None + self.sample_list = self.__search_samples(self.scan_ids) # image files names according to self.scan_ids + if self.is_train: + + self.pid_curr_load = self.scan_ids + elif mode == 'val': + self.pid_curr_load = self.scan_ids + elif mode == 'test': # Source domain test + self.pid_curr_load = self.scan_ids + elif mode == 'test_all': + # Choose this when being used as a target domain testing set. Liu et al. 
+ self.pid_curr_load = self.scan_ids + + if extern_norm_fn is None: + self.normalize_op = get_normalize_op(self.domains[0], [itm['img_fid'] for _, itm in + self.sample_list[self.domains[0]].items() ]) + print(f'{self.phase}_{self.domains[0]}: Using fold data statistics for normalization') + + else: + # assert len(self.domains) == 1, 'for now we only support one normalization function for the entire set' + self.normalize_op = extern_norm_fn + + + # load to memory + # self.sample_list All + self.actual_dataset = None + self.chunksize = chunksize if not debug else 3 + + self.chunk_id = 0 + self.chunk_pool, self.current_chunk = {}, {} + for _domain, item in self.sample_list.items(): + self.chunk_pool[_domain] = list(item.keys()) + + chunk, status = self.next_chunk(self.sample_list) + self.actual_dataset = self.__read_dataset(chunk, status) + self.size = len(self.actual_dataset) # 2D + + print("----- Set up dataset for", self.phase, "with chunksize=", chunksize) + + def update_chunk(self): + chunk, status = self.next_chunk(self.sample_list) + self.actual_dataset = self.__read_dataset(chunk, status) + + def __get_scanids(self, mode, idx_pct): + """ + index by domains given that we might need to load multi-domain data + idx_pct: [0.7 0.1 0.2] for train val test. with order te val tr + """ + tr_ids = {} + val_ids = {} + te_ids = {} + te_all_ids = {} + + for _domain in self.domains: + dset_size = len(self.img_pids[_domain]) + tr_size = round(dset_size * idx_pct[0]) + val_size = math.floor(dset_size * idx_pct[1]) + te_size = dset_size - tr_size - val_size + # print('te_size = ', te_size) + + te_ids[_domain] = self.img_pids[_domain][: te_size] + val_ids[_domain] = self.img_pids[_domain][te_size: te_size + val_size] + tr_ids[_domain] = self.img_pids[_domain][te_size + val_size: ] + te_all_ids[_domain] = list(itertools.chain(tr_ids[_domain], te_ids[_domain], val_ids[_domain] )) + + print(" self.phase = ", self.phase) + if self.phase == 'train': + return tr_ids + elif self.phase == 'val': + return val_ids + elif self.phase == 'test': + return te_ids + elif self.phase == 'test_all': + return te_all_ids + + def __search_samples(self, scan_ids): + """search for filenames for images and masks + """ + out_list = {} + for _domain, id_list in scan_ids.items(): + domain_dir = os.path.join(self._base_dir, _domain) + print("=== reading domains from:", domain_dir) + out_list[_domain] = {} + for curr_id in id_list: + curr_dict = {} + if "BraTS" in _domain: + + _img_fid = os.path.join(domain_dir, 'img', f'{_domain[:-4]}-{curr_id}-000.nii.gz') + if not self.pseudo: + _lb_fid = os.path.join(domain_dir, 'seg', f'{_domain[:-4]}-{curr_id}-000.nii.gz') + else: + _lb_fid = os.path.join(domain_dir, 'seg', f'{_domain[:-4]}-{curr_id}-000.nii.gz.npy') # npy + + _aux_fid = _img_fid.replace(self.main_modality, self.aux_modality) + + + + curr_dict["img_fid"] = _img_fid + curr_dict["lbs_fid"] = _lb_fid + curr_dict["aux_fid"] = _aux_fid + out_list[_domain][str(curr_id)] = curr_dict + + print("=== search sample num:", len(out_list)) + return out_list + + + def filter_with_label(self, img, lb, aux): + # H, W, C, filter zero + if self.phase == "train": + + filter = np.any(np.any(img, axis=0), axis=0) + img, lb, aux = img[..., filter], lb[..., filter], aux[..., filter] + + + if self.filter_non_labeled: + + if self.dataset_key == "knee": + filter2 = np.any(np.any(lb == 2, axis=0), axis=0) + filter4 = np.any(np.any(lb == 4, axis=0), axis=0) + filter = filter2 + filter4 + else: + filter = np.any(np.any(lb, axis=0), axis=0) + + 
filter_right = np.roll(filter, 3) + filter_left = np.roll(filter, -3) + filter = filter + filter_right + filter_left + filter = filter > 0 + + # HWC + img, lb, aux = img[..., filter], lb[..., filter], aux[..., filter] + + return img, lb, aux + + def __read_dataset(self, chunk, status): + """ + Read the dataset into memory + """ + + out_list = [] + self.info_by_scan = {} # meta data of each scan + glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset + for _domain, _curr_chunk in tqdm(chunk.items()): # .items() + domain_ids = 0 + if status[_domain] != 3: + print(f"==== UPDATE dataset for: {_domain} w/ status = {status[_domain]}") + + for scan_id in _curr_chunk: + domain_ids += 1 + if domain_ids > self.chunksize: + print(f"=== UPDATE finished") + break + + itm = self.sample_list[_domain][scan_id] + if scan_id not in self.pid_curr_load[_domain]: + continue + + # Keep the original dataset + if (status[_domain] == 0) or (status[_domain] == 2 and domain_ids <= self.chunksize // 2): + size = self.actual_dataset[glb_idx]['size'] + out_list.extend(self.actual_dataset[glb_idx: glb_idx + size]) # Original dataset + glb_idx += size + continue + + if (status[_domain] == 1 and domain_ids > self.chunksize // 2): + try: + size = self.actual_dataset[glb_idx]['size'] + out_list.extend(self.actual_dataset[glb_idx: glb_idx + size]) # Original dataset + glb_idx += size + continue + except: + print(f"=== Warning (domain_ids={domain_ids}) getting glb_idx={glb_idx} from actual_dataset length={len(self.actual_dataset)}") + # print(self.actual_dataset) + + img, _info = nio.read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out + self.info_by_scan[_domain + '_' + scan_id] = _info + + img_original = np.float32(img) + img = img_original.copy() + + aux = nio.read_nii_bysitk(itm["aux_fid"]) + aux_original = np.float32(aux) + aux = aux_original.copy() + + + # img, self.mean, self.std = self.normalize_op(img) + _, mean, std = self.normalize_op(img) + _, aux_mean, aux_std = self.normalize_op(aux) + + if not self.pseudo: + lb = nio.read_nii_bysitk(itm["lbs_fid"]) + else: + uncertainty_thr = 0.05 # 0.05 + lb_cache = np.load(itm["lbs_fid"], allow_pickle=True).item() + lb = lb_cache['pseudo'].cpu().numpy() # "pseudo": curr_pred, "score":curr_score , "uncertainty" + uncertainty = lb_cache['uncertainty'].cpu().numpy() # Z, C, H, W + uncertainty = np.float32(uncertainty) + + new_lb = np.zeros_like(lb) + for cls in range(self.nclass - 1): + un_mask = (uncertainty[:, cls+1] < uncertainty_thr ) * (cls+1) + new_lb[lb == (cls+1)] = un_mask[lb == (cls+1)] + + lb = new_lb + + lb_original = np.float32(lb) + lb = lb_original.copy() + + # -> H, W, C + img, lb, aux = map(lambda arr: np.transpose(arr, (1, 2, 0)), [img, lb, aux]) + assert img.shape[-1] == lb.shape[-1], f"ASSERT {img.shape} = {lb.shape}" + + # Resize: + if img.shape[1] != self.fineSize[1]: + # H, W, C + res = self.resizer(image=img, mask=lb, image2=aux) + img, lb, aux = res['image'], res['mask'], res['image2'] + + prt_cache = f" {_domain} stat ({domain_ids}/{len(_curr_chunk)}): shape={img.shape}, max={img.max()}, min={img.min()}" + + # Filter vacant slices + if self.phase == "train": + filter = np.any(np.any(img, axis=0), axis=0) + img, lb, aux = img[..., filter], lb[..., filter], aux[..., filter] + + img, lb, aux = self.filter_with_label(img, lb, aux) + + out_list, glb_idx = self.add_to_list(glb_idx, out_list, img, lb, + aux, mean, aux_mean, aux_std, std, _domain, + scan_id, itm["img_fid"]) + + if (domain_ids) % 
(len(_curr_chunk) // 2) == 0: + print(prt_cache + f", filtered shape={img.shape}, mask max={lb.max()}") + + + # Add various axis view !!! + if self.phase == "train" and self.use_diff_axis_view: + # C, W, H + img, lb, aux = img_original, lb_original, aux_original + # Resize: + if img.shape[1] != self.fineSize[1]: + res = self.resizer(image=img, mask=lb, image2=aux) # assume H, W, (C)<- + img, lb, aux = res['image'], res['mask'], res['image2'] + + img, lb, aux = self.filter_with_label(img, lb, aux) + + out_list, glb_idx = self.add_to_list(glb_idx, out_list, img, lb, + aux, mean, aux_mean, aux_std, std, _domain, + scan_id, itm["img_fid"]) + + del img, lb, aux, img_original, lb_original, aux_original + + del self.actual_dataset + return out_list + + def next_chunk(self, all_samples): + # 0 No update, 1 First half, 2 Second half, 3 All updates Chunk + status = {} + self.last_chunk = copy.deepcopy(self.current_chunk) + for _domain, _sample_list in tqdm(all_samples.items()): + # Default value + status[_domain] = 3 + + # Put all in - validation or small dataset + if ((not self.is_train) or len(_sample_list) < self.chunksize) and not self.debug: + self.current_chunk[_domain] = _sample_list + if _domain not in self.last_chunk: + status[_domain] = 3 # all + else: + status[_domain] = 0 # not updates + print("=== Put all data in for", _domain) + continue + + # chunksize + random.shuffle(self.chunk_pool[_domain]) + if _domain not in self.last_chunk: + status[_domain] = 3 + self.current_chunk[_domain] = self.chunk_pool[_domain][:self.chunksize] + self.chunk_pool[_domain] = self.chunk_pool[_domain][self.chunksize:] + + else: + status[_domain] = self.chunk_id//2 + 1 # 1, 2 + candidate = self.chunk_pool[_domain][:self.chunksize//2] + self.chunk_pool[_domain] = self.chunk_pool[_domain][self.chunksize //2:] + if status[_domain] == 1: + self.current_chunk[_domain][:self.chunksize // 2] = candidate + else: + self.current_chunk[_domain][self.chunksize // 2:] = candidate + + if _domain in self.last_chunk: + self.chunk_pool[_domain] = self.chunk_pool[_domain] + self.last_chunk[_domain] + + self.chunk_id += 1 + + return self.current_chunk, status + + + def add_to_list(self, glb_idx, out_list, img, lb, aux, mean, std, aux_mean, aux_std, _domain, scan_id, file_id): + # now start writing everthing in + c = 3 + + for ii in range(img.shape[-1]): + is_end = False + is_start = False + if ii == 0: + is_start = True + # write the beginning frame + if self.input_window == 3: + _img = img[..., 0: c].copy() + _img[..., 1] = _img[..., 0] + elif self.input_window == 1: + _img = img[..., 0: 0 + 1].copy() + + + elif ii < img.shape[-1] - 1: + if self.input_window == 3: + _img = img[..., ii -1: ii + 2].copy() + elif self.input_window == 1: + _img = img[..., ii: ii + 1].copy() + + else: + is_end = True + if self.input_window == 3: + _img = img[..., ii-2: ii + 1].copy() + _img[..., 0] = _img[..., 1] + elif self.input_window == 1: + _img = img[..., ii: ii+ 1].copy() + + _lb = lb[..., ii: ii + 1] + _aux = aux[..., ii: ii + 1] + + out_list.append( + {"img": _img, "lb":_lb, "aux":_aux, "size": img.shape[-1], + "mean":mean, "std":std, + "aux_mean": aux_mean, "aux_std": aux_std, + "is_start": is_start, "is_end": is_end, + "domain": _domain, "nframe": img.shape[-1], + "scan_id": _domain + "_" + scan_id, + "pid": scan_id, "file_id": file_id, "z_id":ii}) + glb_idx += 1 + + return out_list, glb_idx + + + def get_patch_from_img(self, img_H, img_L, img_L2, crop_size=[320, 320], zslice_dim=2): + # -------------------------------- + # randomly crop 
the patch + # -------------------------------- + + H, W, _ = img_H.shape + rnd_h = random.randint(0, max(0, H - crop_size[0])) + rnd_w = random.randint(0, max(0, W - crop_size[1])) + + # image = torch.index_select(image, 0, torch.tensor([1])) + if zslice_dim == 2: + patch_H = img_H[rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1], :] + patch_L = img_L[rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1], :] + patch_L2 = img_L2[rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1], :] + elif zslice_dim == 0: + patch_H = img_H[:, rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1]] + patch_L = img_L[:, rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1]] + patch_L2 = img_L2[:, rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1]] + + return patch_H, patch_L, patch_L2 + + + def __len__(self): + """ + copy-paste from basic naive dataset configuration + """ + return len(self.actual_dataset) diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/brain.py b/MRI_recon/code/Frequency-Diffusion/dataset/brain.py new file mode 100644 index 0000000000000000000000000000000000000000..b7958424e373bc61cbdfc52a0b4348d76cef7dd4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/brain.py @@ -0,0 +1,148 @@ +# Dataloader for abdominal images +import glob +import numpy as np +from .m4_utils import niftiio as nio +from .m4_utils import transform_utils as trans +from .m4_utils.abd_dataset_utils import get_normalize_op +from .m4_utils.transform_albu import get_albu_transforms, get_resize_transforms + +import torch +import os +from pdb import set_trace +from multiprocessing import Process +from .basic import BasicDataset + + +LABEL_NAME = ["bg", "NCR", "ED", "ET"] + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + +def normalize_instance(data, mean=None, std=None, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + if mean is None: + mean = data.mean() + if std is None: + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +class BrainDataset(BasicDataset): + def __init__(self, mode, base_dir, image_size, + nclass, domains, aux_modality, **kwargs): + """ + Args: + mode: 'train', 'val', 'test', 'test_all' + transforms: naive data augmentations used by default. Photometric transformations slightly better than those configured by Zhang et al. 
(bigaug) + idx_pct: train-val-test split for source domain + extern_norm_fn: feeding data normalization functions from external, only concerns CT-MR cross domain scenario + """ + self.dataset_key = "brain" + transforms = get_albu_transforms(mode, image_size) + if isinstance(domains, str): + domains = [domains] + + super(BrainDataset, self).__init__(image_size, mode, + transforms, + base_dir, + domains, aux_modality, + nclass=nclass, + LABEL_NAME=LABEL_NAME, + filter_non_labeled=True, + **kwargs) + + def hwc_to_chw(self,img): + img = np.float32(img) + img = np.transpose(img, (2, 0, 1)) # [C, H, W] + img = torch.from_numpy( img.copy() ) + return img + + def perform_trans(self, img, mask, aux): + + T = self.albu_transform if self.is_train else self.test_resizer + buffer = T(image = img, mask=mask, image2=aux) # [0 - 255] + img, mask, aux = buffer['image'], buffer['mask'], buffer['image2'] + if len(mask.shape) == 2: + mask = mask[..., None] + + # if self.is_train: + # img, mask, aux = self.get_patch_from_img(img, mask, aux, crop_size=self.crop_size) # 192 + + return img, mask, aux + + + def __getitem__(self, index): + index = index % len(self.actual_dataset) + curr_dict = self.actual_dataset[index] # numpy + + # ----------------------- Extract Slice ----------------------- + img, mask, aux = curr_dict["img"], curr_dict["lb"], curr_dict["aux"] # H, W, C, [0 - 255] + domain, pid = curr_dict["domain"], curr_dict["pid"] + mean, std = curr_dict['mean'], curr_dict['std'] + aux_mean, aux_std = curr_dict['aux_mean'], curr_dict['aux_std'] + # max, min = img.max(), img.min() + std = 1 if std < 1e-3 else std + + # img = (img - mean) / std + ### apply the (x - mean) / std normalization to both the input image and the target image + img, img_mean, img_std = normalize_instance(img, eps=1e-6) # mean=mean, std=std, + aux, aux_mean, aux_std = normalize_instance(aux, eps=1e-6) # mean=aux_mean, std=aux_std, + + ### clamp input to ensure training stability.
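        # (Editor's note, added for clarity; the +/-6 bound below is the original authors' choice.)
        # After instance normalization the slices are approximately zero-mean / unit-variance, so
        # clipping to [-6, 6] only discards values more than ~6 standard deviations from the mean
        # while keeping the input range bounded for the diffusion model.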
+ img = np.clip(img, -6, 6) + aux = np.clip(aux, -6, 6) + + + mask = mask[..., 0] + img, mask, aux = self.perform_trans(img, mask, aux) + img, mask, aux = map(lambda arr: self.hwc_to_chw(arr), [img, mask, aux]) + + img = np.clip(img, -6, 6) + aux = np.clip(aux, -6, 6) + + if self.tile_z_dim > 1 and self.input_window == 1 and self.num_channels == 3 : + img = img.repeat( [ self.tile_z_dim, 1, 1] ) + assert img.ndimension() == 3 + + data = {"img": img, "lb": mask, "aux": aux, + "img_mean": img_mean, "img_std": img_std, + "aux_mean": aux_mean, "aux_std": aux_std, + "is_start": curr_dict["is_start"], + "is_end": curr_dict["is_end"], + "nframe": np.int32(curr_dict["nframe"]), + "scan_id": curr_dict["scan_id"], + "z_id": curr_dict["z_id"], + "file_id": curr_dict["file_id"] + } + + return data + + + diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/celeb.py b/MRI_recon/code/Frequency-Diffusion/dataset/celeb.py new file mode 100644 index 0000000000000000000000000000000000000000..43db1aac54d58ed7cae60033e72e811c3467cc16 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/celeb.py @@ -0,0 +1,61 @@ +from comet_ml import Experiment +import math + + +from torch.utils import data +from pathlib import Path +from torchvision import transforms +from PIL import Image + + + + +class Dataset_Aug1(data.Dataset): + def __init__(self, folder, image_size, exts = ['jpg', 'jpeg', 'png']): + super().__init__() + self.folder = folder + self.image_size = image_size + self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] + + self.transform = transforms.Compose([ + transforms.Resize((int(image_size*1.12), int(image_size*1.12))), + transforms.RandomCrop(image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Lambda(lambda t: (t * 2) - 1) + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + img = Image.open(path) + img = img.convert('RGB') + return self.transform(img) + + + +class Dataset(data.Dataset): + def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']): + super().__init__() + self.folder = folder + self.image_size = image_size + self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] + + self.transform = transforms.Compose([ + transforms.Resize((int(image_size*1.12), int(image_size*1.12))), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + transforms.Lambda(lambda t: (t * 2) - 1) + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + img = Image.open(path) + img = img.convert('RGB') + return self.transform(img) + \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/dicom_test.py b/MRI_recon/code/Frequency-Diffusion/dataset/dicom_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ece18c9d53a7429805b18cab2c7b98273e5db9a6 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/dicom_test.py @@ -0,0 +1,75 @@ +import pydicom +# pip install pydicom + + +def print_dicom_metadata(file_path): + # Read the DICOM file + dicom_data = pydicom.dcmread(file_path) + + # Print all metadata + for element in dicom_data: + # Retrieve the tag, name, and value + tag = element.tag + name = element.name + value = element.value + + # Handle different types of values + if isinstance(value, pydicom.multival.MultiValue): + # Join MultiValue elements into a single string + value = ", ".join(str(v) for v in value) + + elif 
isinstance(value, bytes): + # Decode bytes if possible, or represent them as hex + try: + value = value.decode('utf-8') + except UnicodeDecodeError: + value = value.hex() + + # Print the tag, name, and processed value + print(f"{tag} {name}: {value}") + +# Path to your DICOM file +dicom_file_path = "/gamedrive/Datasets/medical/Knee/fastMRI/knee_mri_clinical_seq_batch2/FB_476595____FB,1899398684/study_63e96492/MR2_dd8eb0e8/00031_852687759fd1a2c1.dcm" +dicom_file_path = "/gamedrive/Datasets/medical/Knee/fastMRI/knee_mri_clinical_seq_batch2/FB_476595____FB,1899398684/study_63e96492/MR3_e6e4d154/00013_15898f7eff8d4655.dcm" + + +dicom_data = pydicom.dcmread(dicom_file_path) +# MR2_dd8eb0e8/ MR3_e6e4d154/ MR4_71dd8cd8/ MR5_9419dbc1/ + +# Print metadata +# Extract the Series Description +series_description = dicom_data.get((0x0008, 0x103E), "Series Description not found").value +print(f"Series Description: {series_description}") + +dicom_folder = "/gamedrive/Datasets/medical/Knee/fastMRI/knee_mri_clinical_seq_batch2/" + +dict = {"MR2": [], "MR3":[], "MR4":[], "MR5":[], "MR6":[], "MR7":[], "MR8":[], "MR9":[]} + + +count = 0 +import os, glob +for mrs in glob.glob(f"{dicom_folder}/*/*/*"): + mr_name = mrs.split("/")[-1].split("_")[0] + print(mr_name) + + for root, dirs, files in os.walk(mrs): + for file in files[:1]: + if file.endswith(".dcm"): # Check if the file has a .dcm extension + dicom_file_path = os.path.join(root, file) + + # Read the DICOM file + ds = pydicom.dcmread(dicom_file_path) + + # Extract the Series Description, if available + series_description = ds.get((0x0008, 0x103E), None).value + if series_description: + print(f"File: {file} | Series Description: {series_description}") + else: + print(f"File: {file} | Series Description not found") + dict[mr_name].append(series_description) + count +=1 + if count == 200: + break + +for key, value in dict.items(): + print(f"{key}: {value} ") \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/fastmri.py b/MRI_recon/code/Frequency-Diffusion/dataset/fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..68e3b22b8a78ef2c314b44a29798bb78ded7e726 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/fastmri.py @@ -0,0 +1,338 @@ +import csv +import os +import random +import xml.etree.ElementTree as etree +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import pathlib + +import h5py +import numpy as np +import torch +import yaml +from torch.utils.data import Dataset +# from .transforms import build_transforms +from matplotlib import pyplot as plt +from dataset.m4_utils.transform_albu import get_albu_transforms + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. 
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")): + """ + Data directory fetcher. + + This is a brute-force simple way to configure data directories for a + project. Simply overwrite the variables for `knee_path` and `brain_path` + and this function will retrieve the requested subsplit of the data for use. + + Args: + key (str): key to retrieve path from data_config_file. + data_config_file (pathlib.Path, + default=pathlib.Path("fastmri_dirs.yaml")): Default path config + file. + + Returns: + pathlib.Path: The path to the specified directory. + """ + if not data_config_file.is_file(): + default_config = dict( + knee_path="/home/jc3/Data/", + brain_path="/home/jc3/Data/", + ) + with open(data_config_file, "w") as f: + yaml.dump(default_config, f) + + raise ValueError(f"Please populate {data_config_file} with directory paths.") + + with open(data_config_file, "r") as f: + data_dir = yaml.safe_load(f)[key] + + data_dir = pathlib.Path(data_dir) + + if not data_dir.exists(): + raise ValueError(f"Path {data_dir} from {data_config_file} does not exist.") + + return data_dir + + +def et_query( + root: etree.Element, + qlist: Sequence[str], + namespace: str = "http://www.ismrm.org/ISMRMRD", +) -> str: + """ + ElementTree query function. + This can be used to query an xml document via ElementTree. It uses qlist + for nested queries. + Args: + root: Root of the xml to search through. + qlist: A list of strings for nested searches, e.g. ["Encoding", + "matrixSize"] + namespace: Optional; xml namespace to prepend query. + Returns: + The retrieved data as a string. + """ + s = "." 
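    # (Editor's comment, added for illustration.) For qlist = ["encoding", "encodedSpace",
    # "matrixSize"], the loop below builds the nested ElementTree query string
    # ".//ismrmrd_namespace:encoding//ismrmrd_namespace:encodedSpace//ismrmrd_namespace:matrixSize",
    # which is then resolved against the ISMRMRD namespace mapping defined in `ns`.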
+ prefix = "ismrmrd_namespace" + + ns = {prefix: namespace} + + for el in qlist: + s = s + f"//{prefix}:{el}" + + value = root.find(s, ns) + if value is None: + raise RuntimeError("Element not found") + + return str(value.text) + +from dataset.m4_utils.transforms import build_transforms + +class SliceDataset(Dataset): + def __init__( + self, + root, + transform, + challenge, + input_normalize="mean_std", + image_size=(128, 128), + sample_rate=1, + mode='train', + debug=True, + ): + self.mode = mode + self.transforms = get_albu_transforms(mode, image_size) + self.input_normalize = input_normalize + + # challenge + if challenge not in ("singlecoil", "multicoil"): + raise ValueError('challenge should be either "singlecoil" or "multicoil"') + self.recons_key = ( + "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss" + ) + # transform + self.transform = transform + + self.other_transform = build_transforms("random", + [1], + [1], mode) + self.examples = [] + + self.cur_path = root + print("dataroot = ", root) + if self.mode == "test": + self.csv_file = "./dataset/knee_data_split/singlecoil_train_split_less.csv" + else: + self.csv_file = "./dataset/knee_data_split/singlecoil_" + self.mode + "_split_less.csv" + + with open(self.csv_file, 'r') as f: + reader = csv.reader(f) + + id = 0 + if debug: + reader = list(reader)[:10] + + for row in reader: + pd_metadata, pd_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[0] + '.h5')) + + pdfs_metadata, pdfs_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[1] + '.h5')) + + for slice_id in range(min(pd_num_slices, pdfs_num_slices)): + self.examples.append( + (os.path.join(self.cur_path, row[0] + '.h5'), os.path.join(self.cur_path, row[1] + '.h5') + , slice_id, pd_metadata, pdfs_metadata, id)) + id += 1 + + if sample_rate < 1: + random.shuffle(self.examples) + num_examples = round(len(self.examples) * sample_rate) + + self.examples = self.examples[0:num_examples] + + self.down_transform = None + + def update_chunk(self): + pass + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + + # read pd + pd_fname, pdfs_fname, slice, pd_metadata, pdfs_metadata, id = self.examples[i] + + with h5py.File(pd_fname, "r") as hf: + pd_kspace = hf["kspace"][slice] + pd_mask = np.asarray(hf["mask"]) if "mask" in hf else None + pd_target = hf[self.recons_key][slice] if self.recons_key in hf else None + attrs = dict(hf.attrs) + attrs.update(pd_metadata) + + if self.other_transform is None: + pd_sample = (pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + else: + pd_sample = self.other_transform(pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + + with h5py.File(pdfs_fname, "r") as hf: + pdfs_kspace = hf["kspace"][slice] + pdfs_mask = np.asarray(hf["mask"]) if "mask" in hf else None + pdfs_target = hf[self.recons_key][slice] if self.recons_key in hf else None + attrs = dict(hf.attrs) + attrs.update(pdfs_metadata) + + + if self.other_transform is None: + pdfs_sample = (pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + else: + pdfs_sample = self.other_transform(pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + + + # input size = 1.1693149e-05 7.2921634e-06 7.177928e-05 3.3911466e-08 + # print("input size = ", pdfs_target.mean(), pdfs_target.std(), pdfs_target.max(), pdfs_target.min()) + pdfs_mean = pdfs_sample[2] + pdfs_std = pdfs_sample[3] + pd_mean = pd_sample[2] + pd_std = pd_sample[3] + + pdfs_target = pdfs_sample[1].numpy() + pd_target = 
pd_sample[1].numpy() + + # print("pdfs_target:", pdfs_target.shape, pdfs_target.max(), pdfs_target.min()) + + + # print("pdf=", pdfs_target.shape, pdfs_target.max(), pdfs_target.min()) + # if self.input_normalize == "mean_std": + # + # # print("std:", pdfs_sample[3]) + # # print("mean:", pdfs_sample[2]) + # + # pdfs_target, pdfs_mean, pdfs_std = normalize_instance(pdfs_target, eps=1e-11) + # pd_target, pd_mean, pd_std = normalize_instance(pd_target, eps=1e-11) + # + # elif self.input_normalize == "min_max": + # pdfs_target = (pdfs_target - pdfs_target.min()) / (pdfs_target.max() - pdfs_target.min()) + # pd_target = (pd_target - pd_target.min()) / (pd_target.max() - pd_target.min()) + # pdfs_mean = 0 + # pdfs_std = 1 + # pd_mean = 0 + # pd_std = 1 + # else: + # raise ValueError(f"Unrecognized input normalization: {self.input_normalize}") + + + # return (pd_sample, pdfs_sample, id) + pdfs_main = True # PDWI as the auxiliary and FS-PDWI as the target + + if pdfs_main: + img = pdfs_target + aux = pd_target + img_mean = pdfs_mean + img_std = pdfs_std + aux_mean = pd_mean + aux_std = pd_std + else: + img = pd_target + aux = pdfs_target + img_mean = pd_mean + img_std = pd_std + aux_mean = pdfs_mean + aux_std = pdfs_std + + + data_dict = self.transforms(image=img, image2=aux) + img = data_dict['image'] + aux = data_dict['image2'] + + + # print("===img===:", img.shape, aux.shape, img.max(), img.min()) # (320, 320) (320, 320) + # print("===img_mean===:", img_mean, aux_mean) # 0.0 0.0 + + img = np.expand_dims(img, axis=0) + aux = np.expand_dims(aux, axis=0) + + data = {"img": img, "aux": aux, + "img_mean": img_mean, "img_std": img_std, + "aux_mean": aux_mean, "aux_std": aux_std, + } + return data + + + def _retrieve_metadata(self, fname): + with h5py.File(fname, "r") as hf: + et_root = etree.fromstring(hf["ismrmrd_header"][()]) + + enc = ["encoding", "encodedSpace", "matrixSize"] + enc_size = ( + int(et_query(et_root, enc + ["x"])), + int(et_query(et_root, enc + ["y"])), + int(et_query(et_root, enc + ["z"])), + ) + rec = ["encoding", "reconSpace", "matrixSize"] + recon_size = ( + int(et_query(et_root, rec + ["x"])), + int(et_query(et_root, rec + ["y"])), + int(et_query(et_root, rec + ["z"])), + ) + + lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"] + enc_limits_center = int(et_query(et_root, lims + ["center"])) + enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1 + + padding_left = enc_size[1] // 2 - enc_limits_center + padding_right = padding_left + enc_limits_max + + num_slices = hf["kspace"].shape[0] + + metadata = { + "padding_left": padding_left, + "padding_right": padding_right, + "encoding_size": enc_size, + "recon_size": recon_size, + } + + return metadata, num_slices + + +def build_dataset( mode='train', image_size=128, sample_rate=1): + assert mode in ['train', 'val', 'test'], 'unknown mode' + # transforms = build_transforms(args, mode) + + return SliceDataset(os.path.join(args.root_path, 'singlecoil_' + mode), image_size, 'singlecoil', sample_rate=sample_rate, mode=mode) diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/fsm_dataloaders/__init__.py b/MRI_recon/code/Frequency-Diffusion/dataset/fsm_dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/fsm_dataloaders/hybrid_sparse.py b/MRI_recon/code/Frequency-Diffusion/dataset/fsm_dataloaders/hybrid_sparse.py new file mode 100644 index 
0000000000000000000000000000000000000000..42e4a7e33c2204c13a1c4509897baf19e1fb07f1 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/fsm_dataloaders/hybrid_sparse.py @@ -0,0 +1,156 @@ +from __future__ import print_function, division +import numpy as np +from glob import glob +import random +from skimage import transform + +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', transform=None): + + super().__init__() + self._base_dir = base_dir + self.im_ids = [] + self.images = [] + self.gts = [] + + if split=='train': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir+"/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + + elif split=='test': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir + "/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + self.transform = transform + + assert (len(self.images) == len(self.gts)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.images))) + + def __len__(self): + return len(self.images) + + + def __getitem__(self, index): + img_in, img, target_in, target= self._make_img_gt_point_pair(index) + sample = {'image_in': img_in, 'image':img, 'target_in': target_in, 'target': target} + # print("image in:", img_in.shape) + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + + # the default setting (i.e., rawdata.npz) is 4X64P + dd = np.load(self.images[index].replace('.png', '_raw_4X64P.npz')) + # print("images range:", dd['fbp'].max(), dd['ct'].max(), dd['under_t1'].max(), dd['t1'].max()) + _img_in = dd['fbp'] + _img_in[_img_in>0.6]=0.6 + _img_in = _img_in/0.6 + + _img = dd['ct'] + _img =(_img/1000*0.192+0.192) + _img[_img<0.0]=0.0 + _img[_img>0.6]=0.6 + _img = _img/0.6 + + _target_in = dd['under_t1'] + _target = dd['t1'] + + return _img_in, _img, _target_in, _target + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 400, 400 + crop_size = 384 + pad_size = (400-384)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = 
transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/fsm_dataloaders/kspace_subsample.py b/MRI_recon/code/Frequency-Diffusion/dataset/fsm_dataloaders/kspace_subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5b5694d8fee8b35ba8394fae98fe2d3aa25759 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/fsm_dataloaders/kspace_subsample.py @@ -0,0 +1,287 @@ +""" +2023/10/16, +preprocess kspace data with the undersampling mask in the fastMRI project. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + masked_spectrum = spectrum * mask[None, :, :, None] + return spectrum, masked_spectrum + + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2), norm='ortho') + + return image + + +def add_gaussian_noise(kspace, snr): + ### 根据SNR确定noise的放大比例 + num_pixels = kspace.shape[0]*kspace.shape[1]*kspace.shape[2]*kspace.shape[3] + psr = torch.sum(torch.abs(kspace.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(kspace.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(kspace.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(kspace.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noisy_kspace = kspace + noise + + return noisy_kspace + + +def 
mri_fft(raw_mri, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + + if _SNR > 0: + noisy_kspace = add_gaussian_noise(kspace, _SNR) + else: + noisy_kspace = kspace + + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1) + + + +def undersample_mri(raw_mri, _MRIDOWN, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] # [1, 240] + # print("mask:", mask.shape) + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + ### under-sample the kspace data. + kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) + ### add low-field noise to the kspace data. + if _SNR > 0: + noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) + else: + noisy_kspace = masked_kspace + + ### conver the corrupted kspace data back to noisy MRI image. + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. 
The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
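# A small usage sketch of the mask factories above (hypothetical numbers, not part of
# this patch): the centre columns are always kept and the remaining columns are drawn
# so that the expected total is roughly N / acceleration.
mask_func = create_mask_for_mask_type("random", center_fractions=[0.08], accelerations=[4])
mask = mask_func((1, 320, 1), seed=42)     # columns are sampled along dim -2
print(mask.shape)                          # torch.Size([1, 320, 1])
print(int(mask.sum().item()))              # roughly 320 / 4 = 80 columns kept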
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/h5_test.py b/MRI_recon/code/Frequency-Diffusion/dataset/h5_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2eeb09503b238a43f94ae98589dd9f7152fc65 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/h5_test.py @@ -0,0 +1,16 @@ +import h5py, glob +import xml.etree.ElementTree as ET + + +h5_path = "/Users/haochen/Downloads/singlecoil_test/" # file1000056.h5 + +for h5 in glob.glob(h5_path + "*.h5"): + with h5py.File(h5, "r") as hf: + print("Keys: %s" % hf.keys()) + print("Attrs: %s" % hf.attrs.items()) + print("kspace shape:", hf['kspace'].shape) + # print("ismrmrd_header shape:", hf['ismrmrd_header'].shape) + for key, value in hf.attrs.items(): + print(f" {key}: {value}") + + print() diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/knee_data_split/singlecoil_train_split_less.csv b/MRI_recon/code/Frequency-Diffusion/dataset/knee_data_split/singlecoil_train_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..d85707318750900b14a6e7100541242a60b7a310 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/knee_data_split/singlecoil_train_split_less.csv @@ -0,0 +1,227 @@ +file1000685,file1000568,0.301723929779229 +file1002273,file1000481,0.302226224199571 +file1000472,file1000142,0.304272730770318 +file1002186,file1000863,0.304812175768496 +file1002385,file1002518,0.305357274240413 +file1000981,file1000129,0.305533361411383 +file1001320,file1001948,0.306821514316368 +file1000633,file1002243,0.306892354331709 +file1001872,file1001294,0.308345907393103 +file1001474,file1001830,0.310481695157561 +file1001005,file1001283,0.310497722435023 +file1001690,file1001519,0.310709448786299 +file1002469,file1001811,0.31193137253455 +file1000914,file1000242,0.31237190359308 +file1002284,file1002012,0.315366393843169 +file1001721,file1001328,0.31735122361847 +file1000807,file1002334,0.320096908959039 +file1001944,file1002335,0.320272061156991 +file1002090,file1002431,0.320351887633851 +file1000499,file1002063,0.320786426659383 +file1001362,file1000509,0.32175341740359 +file1001421,file1000597,0.324291432700032 +file1000349,file1000321,0.324545110048573 +file1002123,file1001235,0.327142348994532 +file1001867,file1002086,0.328624781732941 +file1001007,file1001027,0.330759860300298 +file1001915,file1000088,0.331499371283099 +file1001661,file1000313,0.331905252950291 +file1000383,file1000307,0.339998107225229 +file1000116,file1000632,0.34069458535013 +file1002303,file1000173,0.343821267871409 
+file1000306,file1001277,0.344751178043605 +file1000003,file1001922,0.346138116633394 +file1000109,file1000143,0.347632265547478 +file1001999,file1000115,0.348248659775587 +file1000089,file1000326,0.348964657514049 +file1001205,file1002232,0.349375610862454 +file1000557,file1000619,0.351305005151048 +file1001823,file1000778,0.352076809462453 +file1000806,file1001130,0.352659078122633 +file1000365,file1000351,0.352772816610486 +file1002374,file1001778,0.352974481603711 +file1002516,file1001910,0.359896103026675 +file1001200,file1000931,0.360070003966827 +file1001479,file1000952,0.360424533696936 +file1000850,file1001942,0.362632797518558 +file1001426,file1002143,0.363271909822866 +file1001304,file1001333,0.36404737582222 +file1000390,file1000518,0.364744579516818 +file1000830,file1002096,0.365897427529429 +file1000794,file1001856,0.365973692948894 +file1001266,file1001327,0.366395851089761 +file1001692,file1002352,0.36655953875445 +file1001564,file1001024,0.367284385415205 +file1001861,file1002050,0.36783497787384 +file1002066,file1002361,0.367964419694875 +file1001613,file1002087,0.368231014746024 +file1001931,file1000220,0.368847112914793 +file1000339,file1000554,0.370123905662701 +file1000754,file1002208,0.37031588493778 +file1001067,file1001956,0.371313060558732 +file1000101,file1001053,0.372141932838775 +file1002520,file1002409,0.372501194473693 +file1001459,file1001615,0.373295536945146 +file1001673,file1000508,0.376416667681519 +file1002201,file1001228,0.376680033570078 +file1000058,file1002449,0.376927627737029 +file1001748,file1001042,0.378067114701689 +file1001941,file1000376,0.37841176147662 +file1000801,file1002545,0.378423759459738 +file1000010,file1000535,0.38111194591455 +file1000882,file1002154,0.382223600234592 +file1001694,file1001297,0.382545161354354 +file1001992,file1002456,0.382664563820782 +file1001666,file1001773,0.382892588770697 +file1001629,file1002514,0.383417073960824 +file1002113,file1000738,0.385439884728523 +file1002221,file1000569,0.385903801966773 +file1002296,file1002117,0.387319754665673 +file1000693,file1001945,0.387855926202209 +file1001410,file1000223,0.391284037867147 +file1002071,file1001425,0.391497653794399 +file1002325,file1001259,0.391913965917762 +file1002430,file1001969,0.392256443856501 +file1002462,file1000708,0.393161981208355 +file1002358,file1001888,0.39427809496515 +file1000485,file1000753,0.395316199436001 +file1002357,file1001973,0.39564210237905 +file1002130,file1002041,0.395978941103639 +file1002569,file1000097,0.397496127623486 +file1002264,file1000148,0.397630184088734 +file1002381,file1001401,0.398105992102355 +file1000289,file1000585,0.399527637723015 +file1002368,file1001723,0.400243022234875 +file1002342,file1001319,0.400431803928825 +file1002170,file1001226,0.400632448147846 +file1001385,file1001758,0.400855988878681 +file1001732,file1002541,0.40091828863264 +file1001102,file1000762,0.400923140595936 +file1001470,file1000181,0.401353492516182 +file1000400,file1000884,0.401562860630016 +file1002293,file1002523,0.401800994807451 +file1000728,file1001654,0.402763341041675 +file1000582,file1001491,0.403451830806034 +file1000586,file1001521,0.403648293267187 +file1002287,file1001770,0.405194821414496 +file1000371,file1000159,0.405999000381268 +file1002356,file1002064,0.406519210876811 +file1000324,file1000590,0.407593694425997 +file1001622,file1001710,0.40759525378577 +file1002037,file1000403,0.407814136488744 +file1002444,file1000743,0.40943197761463 +file1001175,file1002088,0.410423663035312 
+file1001391,file1000540,0.410854355646853 +file1002133,file1001186,0.411248429534111 +file1001229,file1001630,0.411355571792039 +file1002283,file1000402,0.411836769927671 +file1000627,file1000161,0.412089060388579 +file1001701,file1001402,0.412854774524637 +file1000795,file1000452,0.413448916432685 +file1000354,file1000947,0.41459642292987 +file1002043,file1002505,0.414863932355455 +file1001285,file1001113,0.418183757940871 +file1000170,file1001832,0.419441549204313 +file1002399,file1001500,0.419905873946513 +file1002439,file1000177,0.42054051043224 +file1001656,file1001217,0.420597020703942 +file1000296,file1000065,0.420845042251081 +file1000626,file1001623,0.42087934790355 +file1001767,file1000760,0.422315537515139 +file1000467,file1001246,0.422371268999111 +file1001033,file1000611,0.42425275873442 +file1002304,file1000221,0.425602179771197 +file1001737,file1001141,0.425716789218234 +file1001565,file1000559,0.426158561043574 +file1000249,file1000643,0.426541100077021 +file1002014,file1001109,0.426587840438723 +file1002006,file1000790,0.427829459781438 +file1000193,file1000750,0.428103808477214 +file1001993,file1001110,0.428186367615143 +file1002094,file1001814,0.428868578868176 +file1000098,file1001420,0.428968675677784 +file1000336,file1000211,0.430347427208789 +file1001498,file1002568,0.43204475404071 +file1001671,file1001106,0.432215802861284 +file1000426,file1002386,0.43283446816702 +file1001520,file1002481,0.434867670495723 +file1002189,file1001432,0.434924370194975 +file1001390,file1002554,0.435313848731387 +file1002166,file1001982,0.435387512979012 +file1001120,file1001006,0.435594761785839 +file1000149,file1001985,0.436289528591294 +file1001632,file1001008,0.436682374331417 +file1002567,file1001155,0.437221000601772 +file1000434,file1002195,0.438098100114814 +file1002532,file1001048,0.438500899539101 +file1001605,file1000927,0.438686659342641 +file1000479,file1000120,0.439587267995034 +file1002473,file1001388,0.439594997597548 +file1001108,file1002228,0.440528754793898 +file1002099,file1002056,0.440776843467602 +file1000191,file1002127,0.441114509542672 +file1000875,file1002494,0.441378135507993 +file1002161,file1000002,0.441912476744187 +file1002269,file1001220,0.442742296865228 +file1001295,file1001355,0.4435162405589 +file1001659,file1001023,0.444686151316673 +file1001857,file1001378,0.447500830900898 +file1001183,file1001370,0.447782748040587 +file1000428,file1000859,0.448328910257083 +file1000588,file1002227,0.448650488897259 +file1001098,file1000486,0.448862467740607 +file1001288,file1000408,0.450363676957042 +file1002097,file1001210,0.451126832474666 +file1000216,file1001082,0.451550143520946 +file1001746,file1001642,0.451781042569196 +file1002388,file1000204,0.451940333555972 +file1000021,file1000560,0.452234621797968 +file1000489,file1001545,0.452796032302523 +file1001116,file1000883,0.453096911915119 +file1001372,file1000561,0.45532542913335 +file1001276,file1000424,0.45534174289324 +file1000974,file1002098,0.455371894001872 +file1002566,file1002044,0.455937677517583 +file1000262,file1002046,0.456056330767294 +file1001619,file1001342,0.456559091350965 +file1000045,file1001616,0.457599407743834 +file1001468,file1002115,0.458095965024278 +file1001061,file1000233,0.460561351667266 +file1000558,file1000100,0.461094222462111 +file1000605,file1000691,0.461429521647285 +file1000640,file1000384,0.463383466503099 +file1000410,file1001358,0.463452482427773 +file1000851,file1001014,0.463558384057952 +file1001092,file1000138,0.463591264436099 
+file1000061,file1002049,0.465778207162619 +file1001206,file1000983,0.466701211830884 +file1000256,file1000475,0.466865377968187 +file1002434,file1001387,0.467154181996099 +file1001036,file1000210,0.470404279499276 +file1001540,file1001860,0.472822271037545 +file1001244,file1001154,0.475076170733515 +file1000131,file1001526,0.475459563440874 +file1000180,file1002045,0.476814451110009 +file1001837,file1000637,0.478851985878026 +file1002425,file1001891,0.481451070031007 +file1001056,file1000682,0.482320170742015 +file1002276,file1000777,0.483452141843029 +file1001139,file1002544,0.487462418948035 +file1000548,file1001257,0.488098081542811 +file1000188,file1001286,0.488423105111001 +file1001879,file1000999,0.488449105381724 +file1001062,file1000231,0.48930683373911 +file1000040,file1001873,0.492070802214623 +file1002286,file1000066,0.493213986773381 +file1002474,file1002563,0.501584439120211 +file1000967,file1000563,0.502066261411662 +file1001307,file1002048,0.50460435259807 +file1000483,file1001699,0.511819026566198 +file1001528,file1000285,0.512629017841038 +file1001742,file1002371,0.513805213204644 +file1002397,file1000592,0.515406473057 +file1000069,file1000510,0.528220553613126 +file1001087,file1001300,0.536510449049583 +file1001991,file1000836,0.538145797125916 +file1001382,file1001806,0.538539506621535 +file1000111,file1001189,0.557690760784602 diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/knee_data_split/singlecoil_val_split_less.csv b/MRI_recon/code/Frequency-Diffusion/dataset/knee_data_split/singlecoil_val_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..a1cbac5537562063359f4ac3e0985de51cb989b2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/knee_data_split/singlecoil_val_split_less.csv @@ -0,0 +1,45 @@ +file1000323,file1002538,0.30754967523156 +file1001458,file1001566,0.310512744537048 +file1000885,file1001059,0.318226346221521 +file1000464,file1000196,0.321465466968232 +file1000314,file1000178,0.327505552363568 +file1001163,file1001289,0.328954963947692 +file1000033,file1001191,0.330925609207301 +file1000976,file1000990,0.344036229323198 +file1001930,file1001834,0.345994076497818 +file1002546,file1001344,0.351762252794677 +file1000277,file1001429,0.353297786572139 +file1001893,file1001262,0.358064285890878 +file1000926,file1002067,0.360639004205491 +file1001650,file1002002,0.362186928073579 +file1001184,file1001655,0.362592305723707 +file1001497,file1001338,0.365599407221502 +file1001202,file1001365,0.3844323497275 +file1001126,file1002340,0.388929627976346 +file1001339,file1000291,0.391300537691403 +file1002187,file1001862,0.39883786878841 +file1000041,file1000591,0.39896683485823 +file1001064,file1001850,0.399687813966601 +file1001331,file1002214,0.400340820924839 +file1000831,file1000528,0.403582747590964 +file1000769,file1000538,0.405298051020298 +file1000182,file1001968,0.407646172205036 +file1002382,file1001651,0.410749052045234 +file1000660,file1000476,0.415423894745454 +file1002570,file1001726,0.424622351472032 +file1001585,file1000858,0.426738511964108 +file1000190,file1000593,0.428080574167047 +file1001170,file1001090,0.429987089825525 +file1002252,file1001440,0.432038842370013 +file1000697,file1001144,0.432558506761396 +file1001077,file1000000,0.441922503777368 +file1001381,file1001119,0.455418270809002 +file1001759,file1001851,0.460824505737749 +file1000635,file1002389,0.465674267492171 +file1001668,file1001689,0.467330511330772 +file1001221,file1000818,0.469630000354232 
+file1001298,file1002145,0.473526387887779 +file1001763,file1001938,0.47398893150184 +file1001444,file1000942,0.48507438696692 +file1000735,file1002007,0.496530240691134 +file1000477,file1000280,0.528508000547834 diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/kspace.py b/MRI_recon/code/Frequency-Diffusion/dataset/kspace.py new file mode 100644 index 0000000000000000000000000000000000000000..dc79b77ed1e78f86ba46d55364d84cfa449060d0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/kspace.py @@ -0,0 +1,34 @@ +import torch +from torch import nn +import os +import cv2 +import gc +import numpy as np +from scipy.io import * +from scipy.fftpack import * + + + +# Fourier Transform +def fft_map(x): + fft_x = torch.fft.fftn(x) + fft_x_real = fft_x.real + fft_x_imag = fft_x.imag + + return fft_x_real, fft_x_imag + + +def undersample_kspace(x, mask, is_noise, noise_level, noise_var): + + fft = fft2(x) + fft = fftshift(fft) + fft = fft * mask + + if is_noise: + raise NotImplementedError + fft = fft + generate_gaussian_noise(fft, noise_level, noise_var) + + fft = ifftshift(fft) + x = ifft2(fft) + + return x \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/kspace_subsample.py b/MRI_recon/code/Frequency-Diffusion/dataset/kspace_subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..2634efacb70f129d616c385d17a3c8577ee9f9d4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/kspace_subsample.py @@ -0,0 +1,379 @@ +""" +2023/10/16, +preprocess kspace data with the undersampling mask in the fastMRI project. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + masked_spectrum = spectrum * mask[None, :, :, None] + return spectrum, masked_spectrum + + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2), norm='ortho') + # image = torch.fft.fftshift(image, dim=[1, 2]) + + return image + + +def add_gaussian_noise(kspace, snr): + ### 根据SNR确定noise的放大比例 + num_pixels = kspace.shape[0]*kspace.shape[1]*kspace.shape[2]*kspace.shape[3] + psr = torch.sum(torch.abs(kspace.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(kspace.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(kspace.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(kspace.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noisy_kspace = kspace + noise + + return noisy_kspace + + +# def mri_fft(raw_mri, _SNR): +# mri = torch.tensor(raw_mri)[None, :, :, 
None].to(torch.float32) +# spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') +# # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum +# kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + +# if _SNR > 0: +# noisy_kspace = add_gaussian_noise(kspace, _SNR) +# else: +# noisy_kspace = kspace + +# noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) +# noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + +# return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ +# kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1) + +def mri_fft_m4raw(lq_mri, hq_mri): + # breakpoint() + lq_mri = torch.tensor(lq_mri[0])[None, :, :, None].to(torch.float32) + lq_mri_spectrum = torch.fft.fftn(lq_mri, dim=(1, 2), norm='ortho') + lq_mri_spectrum = torch.fft.fftshift(lq_mri_spectrum, dim=(1, 2)) + + lq_mri = mri_inver_fourier_transform_2d(lq_mri_spectrum) + lq_mri = torch.cat([torch.real(lq_mri), torch.imag(lq_mri)], dim=-1) + lq_kspace = torch.cat([torch.real(lq_mri_spectrum), torch.imag(lq_mri_spectrum)], dim=-1) + + + hq_mri = torch.tensor(hq_mri[0])[None, :, :, None].to(torch.float32) + hq_mri_spectrum = torch.fft.fftn(hq_mri, dim=(1, 2), norm='ortho') + hq_mri_spectrum = torch.fft.fftshift(hq_mri_spectrum, dim=(1, 2)) + + hq_mri = mri_inver_fourier_transform_2d(hq_mri_spectrum) + hq_mri = torch.cat([torch.real(hq_mri), torch.imag(hq_mri)], dim=-1) + hq_kspace = torch.cat([torch.real(hq_mri_spectrum), torch.imag(hq_mri_spectrum)], dim=-1) + + # breakpoint() + return lq_kspace[0], lq_mri[0].permute(2, 0, 1), \ + hq_kspace[0], hq_mri[0].permute(2, 0, 1) + +def mri_fft(raw_mri, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + + if _SNR > 0: + noisy_kspace = add_gaussian_noise(kspace, _SNR) + else: + noisy_kspace = kspace + + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.cat([torch.real(noisy_mri), torch.imag(noisy_mri)], dim=-1) + noisy_kspace = torch.cat([torch.real(noisy_kspace), torch.imag(noisy_kspace)], dim=-1) + + raw_ksapce = torch.cat([torch.real(kspace), torch.imag(kspace)], dim=-1) + raw_mri = mri_inver_fourier_transform_2d(kspace) + raw_mri = torch.cat([torch.real(raw_mri), torch.imag(raw_mri)], dim=-1) + + # breakpoint() + return noisy_kspace[0], noisy_mri[0].permute(2, 0, 1), \ + raw_ksapce[0], raw_mri[0].permute(2, 0, 1) + +def undersample_mri(raw_mri, _MRIDOWN, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + raw_spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + raw_kspace = torch.fft.fftshift(raw_spectrum, dim=(1, 2)) + + if not _MRIDOWN == "0X": + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] # [1, 240] + + ### under-sample the kspace data. + kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) + + ### add low-field noise to the kspace data. 
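        # The branches below optionally corrupt k-space with complex Gaussian noise whose
        # per-channel power is derived from the requested SNR: P_noise = P_signal / 10**(SNR/10),
        # computed separately for the real and imaginary parts inside add_gaussian_noise().
        # Passing _SNR <= 0 disables the added noise and keeps the (masked) k-space as is.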
+ if _SNR > 0: + noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) + else: + noisy_kspace = masked_kspace + + else: + if _SNR > 0: + noisy_kspace = add_gaussian_noise(raw_kspace, _SNR) + else: + noisy_kspace = raw_kspace + + mask = torch.ones([1,240]) + # breakpoint() + ### conver the corrupted kspace data back to noisy MRI image. + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.cat([torch.real(noisy_mri), torch.imag(noisy_mri)], dim=-1) + + noisy_kspace_ = torch.cat([torch.real(noisy_kspace), torch.imag(noisy_kspace)], dim=-1) + + raw_mri = mri_inver_fourier_transform_2d(raw_kspace) + raw_mri = torch.cat([torch.real(raw_mri), torch.imag(raw_mri)], dim=-1) + raw_kspace = torch.cat([torch.real(raw_kspace), torch.imag(raw_kspace)], dim=-1) + + return noisy_kspace_[0], noisy_mri[0].permute(2, 0, 1), \ + raw_kspace[0], raw_mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + +# def undersample_mri(raw_mri, _MRIDOWN, _SNR): +# mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) +# if _MRIDOWN == "4X": +# mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 +# elif _MRIDOWN == "8X": +# mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + +# ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + +# shape = [240, 240, 1] +# mask = ff(shape, seed=1337) +# mask = mask[:, :, 0] # [1, 240] +# # print("mask:", mask.shape) +# # print("original MRI:", mri) + +# # print("original MRI:", mri.shape) +# ### under-sample the kspace data. +# kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) +# ### add low-field noise to the kspace data. +# if _SNR > 0: +# noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) +# else: +# noisy_kspace = masked_kspace + +# ### conver the corrupted kspace data back to noisy MRI image. +# noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) +# noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + +# return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ +# kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. 
If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. 
The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/m4_utils.py b/MRI_recon/code/Frequency-Diffusion/dataset/m4_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0d76246b0c8fb757b6b589b7f889e0316696d0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/m4_utils.py @@ -0,0 +1,272 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. + + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + +# def fft2c(data): +# """ +# Apply centered 2 dimensional Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The FFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# data = torch.fft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + +# def ifft2c(data): +# """ +# Apply centered 2-dimensional Inverse Fast Fourier Transform. 
+ +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The IFFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# # data = torch.ifft(data, 2, normalized=True) +# data = torch.fft.ifft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. 
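# A hypothetical round-trip check of the centered transforms above: a real-valued image
# stored as an (..., H, W, 2) tensor is mapped to k-space with fft2c and recovered with
# ifft2c, and complex_abs collapses the trailing real/imaginary dimension.
import torch

x = torch.zeros(1, 240, 240, 2)
x[..., 0] = torch.rand(1, 240, 240)          # real part only, zero imaginary part
k = fft2c(x)                                  # centered, orthonormal 2D FFT
x_rec = ifft2c(k)
print(torch.allclose(x, x_rec, atol=1e-5))    # True up to numerical error
print(complex_abs(x_rec).shape)               # torch.Size([1, 240, 240])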
+ """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/m4raw_dataloader.py b/MRI_recon/code/Frequency-Diffusion/dataset/m4raw_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..0df4823e792595b4fcf066350c62ea30c02ec443 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/m4raw_dataloader.py @@ -0,0 +1,488 @@ + +from __future__ import print_function, division +from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union +import sys +sys.path.append('.') +from glob import glob +import os +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import numpy as np +import torch +from torch.utils.data import Dataset + +import h5py +from matplotlib import pyplot as plt +from dataloaders.utils import ifft2c, fft2c, complex_abs +from dataloaders.kspace_subsample import create_mask_for_mask_type + +import argparse +from torch.utils.data import DataLoader +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. 
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +# def normal(x): +# y = np.zeros_like(x) +# for i in range(y.shape[0]): +# x_min = x[i].min() +# x_max = x[i].max() +# y[i] = (x[i] - x_min)/(x_max-x_min) +# return y + + + +def undersample_mri(kspace, _MRIDOWN): + # print("kspace shape:", kspace.shape) ## [18, 4, 256, 256, 2] + + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [256, 256, 1] + mask = ff(shape, seed=1337) ## [1, 256, 1] + + mask = mask[:, :, 0] # [1, 256] + + masked_kspace = kspace * mask[None, None, :, :, None] + + return masked_kspace, mask.unsqueeze(-1) + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. 
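# A hypothetical example of the instance normalization used by this loader, and how a
# prediction would be mapped back to the original intensity range (the inverse is
# simply x * std + mean, which is why each slice's mean/std are kept alongside the sample):
import torch

img = torch.rand(240, 240) * 3.0
norm_img, mean, std = normalize_instance(img, eps=1e-11)
print(float(norm_img.mean()), float(norm_img.std()))   # ~0.0 and ~1.0
restored = norm_img * std + mean                        # invert (data - mean) / (std + eps)
print(torch.allclose(img, restored, atol=1e-4))         # True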
+ """ + return torch.sqrt((data ** 2).sum(dim)) + +def read_h5(file_name, _MRIDOWN): + crop_size=[240,240] + + hf = h5py.File(file_name) + volume_kspace = hf['kspace'][()] + + slice_kspace = volume_kspace + slice_kspace = to_tensor(slice_kspace) + import imageio as io + + + + # masked_kspace, mask = apply_mask(slice_kspace, mask_func, seed=123456) + masked_kspace, mask = undersample_mri(slice_kspace, _MRIDOWN) + lq_image = ifft2c(masked_kspace) + lq_image = complex_center_crop(lq_image, crop_size) + lq_image = complex_abs(lq_image) + lq_image = rss(lq_image, dim=1) + # breakpoint() + # io.imsave('lq_image.png', lq_image[0].numpy().astype(np.uint8)) + lq_image_list=[] + mean_list=[] + std_list=[] + for i in range(lq_image.shape[0]): + image, mean, std = normalize_instance(lq_image[i], eps=1e-11) + image = image.clamp(-6, 6) + lq_image_list.append(image) + mean_list.append(mean) + std_list.append(std) + + target = ifft2c(slice_kspace) + target = complex_center_crop(target, crop_size) + target = complex_abs(target) + target = rss(target, dim=1) + # io.imsave('target1.png', target[10].numpy().astype(np.uint8)) + # breakpoint() + target_list=[] + + for i in range(lq_image.shape[0]): + target_slice = normalize(target[i], mean_list[i], std_list[i], eps=1e-11) + target_slice = target_slice.clamp(-6, 6) + target_list.append(target_slice) + + return torch.stack(lq_image_list), torch.stack(target_list), torch.stack(mean_list), torch.stack(std_list) + + + + +class M4Raw_TrainSet(Dataset): + def __init__(self, args): + # mask_func = create_mask_for_mask_type( + # args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + # ) + + self._MRIDOWN = args.MRIDOWN + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + T2_input_list = [input_list1, input_list2, input_list3] + + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + + print('TrainSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, self._MRIDOWN) + + + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + self.T2_images = 
self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + print("Train data shape:", self.T1_images.shape) + # breakpoint() + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. + # choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + # breakpoint() + # import imageio as io + # io.imsave('T1_images.png', (T1_images[0]*255).astype(np.uint8)) + # io.imsave('T1_labels.png', (T1_labels[0]*255).astype(np.uint8)) + # io.imsave('T2_images.png', (T2_images[0]*255).astype(np.uint8)) + # io.imsave('T2_labels.png', (T2_labels[0]*255).astype(np.uint8)) + # breakpoint() + + t1_in=T1_images + t1=T1_labels + t2_in=T2_images + t2=T2_labels + + sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + + # breakpoint() + sample = {'image_in': t1_in, + 'image': t1, + + 'target_in': t2_in, + 'target': t2} + + return sample, sample_stats + + + +class M4Raw_TestSet(Dataset): + def __init__(self, args): + # mask_func = create_mask_for_mask_type( + # args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + # ) + self._MRIDOWN = args.MRIDOWN + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val' + '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + T2_input_list = [input_list1,input_list2,input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + print('TrainSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, self._MRIDOWN) + + + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + self.T2_images = 
self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + print("Train data shape:", self.T1_images.shape) + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + # choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. + choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + # breakpoint() + # import imageio as io + # io.imsave('T1_images.png', (T1_images[0]*255).astype(np.uint8)) + # io.imsave('T1_labels.png', (T1_labels[0]*255).astype(np.uint8)) + # io.imsave('T2_images.png', (T2_images[0]*255).astype(np.uint8)) + # io.imsave('T2_labels.png', (T2_labels[0]*255).astype(np.uint8)) + # breakpoint() + + t1_in=T1_images + t1=T1_labels + t2_in=T2_images + t2=T2_labels + + sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + + # breakpoint() + sample = {'image_in': t1_in, + 'image': t1, + + 'target_in': t2_in, + 'target': t2} + + return sample, sample_stats + + + +def compute_metrics(image, labels): + MSE = mean_squared_error(labels, image)/np.var(labels) + PSNR = peak_signal_noise_ratio(labels, image) + SSIM = structural_similarity(labels, image) + + # print("metrics:", MSE, PSNR, SSIM) + + return MSE, PSNR, SSIM + +def complex_abs_eval(data): + return (data[0:1, :, :] ** 2 + data[1:2, :, :] ** 2).sqrt() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--root_path', type=str, default='/data/qic99/MRI_recon/') + parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') + parser.add_argument('--kspace_refine', type=str, default='False', help='whether use the image reconstructed from kspace network.') + args = parser.parse_args() + + db_test = M4Raw_TestSet(args) + testloader = DataLoader(db_test, batch_size=4, shuffle=False, num_workers=4, pin_memory=True) + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + save_dir = "./visualize_images/" + + for i_batch, sampled_batch in enumerate(testloader): + # t1_in, t2_in = sampled_batch['t1_in'].cuda(), sampled_batch['t2_in'].cuda() + # t1, t2 = sampled_batch['t1_labels'].cuda(), sampled_batch['t2_labels'].cuda() + + t1_in, t2_in = sampled_batch['ref_image_sub'].cuda(), sampled_batch['tag_image_sub'].cuda() + t1, t2 = sampled_batch['ref_image_full'].cuda(), sampled_batch['tag_image_full'].cuda() + + # breakpoint() + for j in range(t1_in.shape[0]): + # t1_in_img = (np.clip(complex_abs_eval(t1_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(complex_abs_eval(t1[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(complex_abs_eval(t2_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(complex_abs_eval(t2[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # breakpoint() + t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_in_img = 
(np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + # t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + + # print(t1_in_img.shape, t1_img.shape) + t1_MSE, t1_PSNR, t1_SSIM = compute_metrics(t1_in_img, t1_img) + t2_MSE, t2_PSNR, t2_SSIM = compute_metrics(t2_in_img, t2_img) + + + t1_MSE_all.append(t1_MSE) + t1_PSNR_all.append(t1_PSNR) + t1_SSIM_all.append(t1_SSIM) + + t2_MSE_all.append(t2_MSE) + t2_PSNR_all.append(t2_PSNR) + t2_SSIM_all.append(t2_SSIM) + + + # print("t1_PSNR:", t1_PSNR_all) + print("t1_PSNR:", round(np.array(t1_PSNR_all).mean(), 4)) + print("t1_NMSE:", round(np.array(t1_MSE_all).mean(), 4)) + print("t1_SSIM:", round(np.array(t1_SSIM_all).mean(), 4)) + + print("t2_PSNR:", round(np.array(t2_PSNR_all).mean(), 4)) + print("t2_NMSE:", round(np.array(t2_MSE_all).mean(), 4)) + print("t2_SSIM:", round(np.array(t2_SSIM_all).mean(), 4)) diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/m4raw_std_dataloader.py b/MRI_recon/code/Frequency-Diffusion/dataset/m4raw_std_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ee7835f421a22ed9a8514884bc95e1498dc378 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/m4raw_std_dataloader.py @@ -0,0 +1,487 @@ + +from __future__ import print_function, division +from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union +import sys +sys.path.append('.') +from glob import glob +import os +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import numpy as np +import torch +from torch.utils.data import Dataset + +import h5py +from matplotlib import pyplot as plt +from dataloaders.m4_utils import ifft2c, fft2c, complex_abs +from dataloaders.kspace_subsample import create_mask_for_mask_type + +import argparse +from torch.utils.data import DataLoader +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. 
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +# def normal(x): +# y = np.zeros_like(x) +# for i in range(y.shape[0]): +# x_min = x[i].min() +# x_max = x[i].max() +# y[i] = (x[i] - x_min)/(x_max-x_min) +# return y + + + +# def undersample_mri(kspace, _MRIDOWN): +# # print("kspace shape:", kspace.shape) ## [18, 4, 256, 256, 2] + +# if _MRIDOWN == "4X": +# mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 +# elif _MRIDOWN == "8X": +# mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + +# ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + +# shape = [256, 256, 1] +# mask = ff(shape, seed=1337) ## [1, 256, 1] + +# mask = mask[:, :, 0] # [1, 256] + +# masked_kspace = kspace * mask[None, None, :, :, None] + +# return masked_kspace, mask.unsqueeze(-1) + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. 
+ + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + +def read_h5(file_name, mask_func): + crop_size=[240,240] + + hf = h5py.File(file_name) + volume_kspace = hf['kspace'][()] + + slice_kspace = volume_kspace + slice_kspace = to_tensor(slice_kspace) + import imageio as io + + + + masked_kspace, mask = apply_mask(slice_kspace, mask_func, seed=123456) + lq_image = ifft2c(masked_kspace) + lq_image = complex_center_crop(lq_image, crop_size) + lq_image = complex_abs(lq_image) + lq_image = rss(lq_image, dim=1) + # breakpoint() + # io.imsave('lq_image.png', lq_image[0].numpy().astype(np.uint8)) + lq_image_list=[] + mean_list=[] + std_list=[] + for i in range(lq_image.shape[0]): + image, mean, std = normalize_instance(lq_image[i], eps=1e-11) + image = image.clamp(-6, 6) + lq_image_list.append(image) + mean_list.append(mean) + std_list.append(std) + + target = ifft2c(slice_kspace) + target = complex_center_crop(target, crop_size) + target = complex_abs(target) + target = rss(target, dim=1) + # io.imsave('target1.png', target[10].numpy().astype(np.uint8)) + # breakpoint() + target_list=[] + + for i in range(lq_image.shape[0]): + target_slice = normalize(target[i], mean_list[i], std_list[i], eps=1e-11) + target_slice = target_slice.clamp(-6, 6) + target_list.append(target_slice) + + return torch.stack(lq_image_list), torch.stack(target_list), torch.stack(mean_list), torch.stack(std_list) + + + + +class M4Raw_TrainSet(Dataset): + def __init__(self, args): + mask_func = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + + self._MRIDOWN = args.MRIDOWN + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + T2_input_list = [input_list1, input_list2, input_list3] + + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + + print('TrainSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, mask_func) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, mask_func) + + + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + 
self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + print("Train data shape:", self.T1_images.shape) + # breakpoint() + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. + # choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + # breakpoint() + # import imageio as io + # io.imsave('T1_images.png', (T1_images[0]*255).astype(np.uint8)) + # io.imsave('T1_labels.png', (T1_labels[0]*255).astype(np.uint8)) + # io.imsave('T2_images.png', (T2_images[0]*255).astype(np.uint8)) + # io.imsave('T2_labels.png', (T2_labels[0]*255).astype(np.uint8)) + # breakpoint() + + t1_in=T1_images + t1=T1_labels + t2_in=T2_images + t2=T2_labels + + sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + + # breakpoint() + sample = {'image_in': t1_in, + 'image': t1, + + 'target_in': t2_in, + 'target': t2} + + return sample, sample_stats + + + +class M4Raw_TestSet(Dataset): + def __init__(self, args): + mask_func = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val' + '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + T2_input_list = [input_list1,input_list2,input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + print('TrainSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, mask_func) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, mask_func) + + + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + 
self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + print("Train data shape:", self.T1_images.shape) + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + # choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. + choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + # breakpoint() + # import imageio as io + # io.imsave('T1_images.png', (T1_images[0]*255).astype(np.uint8)) + # io.imsave('T1_labels.png', (T1_labels[0]*255).astype(np.uint8)) + # io.imsave('T2_images.png', (T2_images[0]*255).astype(np.uint8)) + # io.imsave('T2_labels.png', (T2_labels[0]*255).astype(np.uint8)) + # breakpoint() + + t1_in=T1_images + t1=T1_labels + t2_in=T2_images + t2=T2_labels + + sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + + # breakpoint() + sample = {'image_in': t1_in, + 'image': t1, + + 'target_in': t2_in, + 'target': t2} + + return sample, sample_stats + + + +def compute_metrics(image, labels): + MSE = mean_squared_error(labels, image)/np.var(labels) + PSNR = peak_signal_noise_ratio(labels, image) + SSIM = structural_similarity(labels, image) + + # print("metrics:", MSE, PSNR, SSIM) + + return MSE, PSNR, SSIM + +def complex_abs_eval(data): + return (data[0:1, :, :] ** 2 + data[1:2, :, :] ** 2).sqrt() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--root_path', type=str, default='/data/qic99/MRI_recon/') + parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') + parser.add_argument('--kspace_refine', type=str, default='False', help='whether use the image reconstructed from kspace network.') + args = parser.parse_args() + + db_test = M4Raw_TestSet(args) + testloader = DataLoader(db_test, batch_size=4, shuffle=False, num_workers=4, pin_memory=True) + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + save_dir = "./visualize_images/" + + for i_batch, sampled_batch in enumerate(testloader): + # t1_in, t2_in = sampled_batch['t1_in'].cuda(), sampled_batch['t2_in'].cuda() + # t1, t2 = sampled_batch['t1_labels'].cuda(), sampled_batch['t2_labels'].cuda() + + t1_in, t2_in = sampled_batch['ref_image_sub'].cuda(), sampled_batch['tag_image_sub'].cuda() + t1, t2 = sampled_batch['ref_image_full'].cuda(), sampled_batch['tag_image_full'].cuda() + + # breakpoint() + for j in range(t1_in.shape[0]): + # t1_in_img = (np.clip(complex_abs_eval(t1_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(complex_abs_eval(t1[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(complex_abs_eval(t2_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(complex_abs_eval(t2[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # breakpoint() + t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_img = 
(np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + # t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + + # print(t1_in_img.shape, t1_img.shape) + t1_MSE, t1_PSNR, t1_SSIM = compute_metrics(t1_in_img, t1_img) + t2_MSE, t2_PSNR, t2_SSIM = compute_metrics(t2_in_img, t2_img) + + + t1_MSE_all.append(t1_MSE) + t1_PSNR_all.append(t1_PSNR) + t1_SSIM_all.append(t1_SSIM) + + t2_MSE_all.append(t2_MSE) + t2_PSNR_all.append(t2_PSNR) + t2_SSIM_all.append(t2_SSIM) + + + # print("t1_PSNR:", t1_PSNR_all) + print("t1_PSNR:", round(np.array(t1_PSNR_all).mean(), 4)) + print("t1_NMSE:", round(np.array(t1_MSE_all).mean(), 4)) + print("t1_SSIM:", round(np.array(t1_SSIM_all).mean(), 4)) + + print("t2_PSNR:", round(np.array(t2_PSNR_all).mean(), 4)) + print("t2_NMSE:", round(np.array(t2_MSE_all).mean(), 4)) + print("t2_SSIM:", round(np.array(t2_SSIM_all).mean(), 4)) diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/utils/abd_dataset_utils.py b/MRI_recon/code/Frequency-Diffusion/dataset/utils/abd_dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..25827ddeb9bb48fa5680b87d111b841ad2ebb892 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/utils/abd_dataset_utils.py @@ -0,0 +1,65 @@ +""" +Utils for datasets +""" +import numpy as np +import os +import sys +import numpy as np +import pdb +import SimpleITK as sitk +from .niftiio import read_nii_bysitk + + +def get_normalize_op(modality, fids): + """ + As title + Args: + modality: CT or MR + fids: fids for the fold + """ + + def get_CT_statistics(scan_fids): + """ + As CT are quantitative, get mean and std for CT images for image normalizing + As in reality we might not be able to load all images at a time, we would better detach statistics calculation with actual data loading + However, in unseen dataset we have no clues about the data statistics at all so just normalize each 3D image to zero mean unit variance + """ + total_val = 0 + n_pix = 0 + for fid in scan_fids: + in_img = read_nii_bysitk(fid) + total_val += in_img.sum() + n_pix += np.prod(in_img.shape) + del in_img + meanval = total_val / n_pix + + total_var = 0 + for fid in scan_fids: + in_img = read_nii_bysitk(fid) + total_var += np.sum((in_img - meanval) ** 2 ) + del in_img + var_all = total_var / n_pix + + global_std = var_all ** 0.5 + + return meanval, global_std + + + if modality == 'SABSCT': + ct_mean, ct_std = get_CT_statistics(fids) + + def CT_normalize(x_in): + """ + Normalizing CT images, based on global statistics + """ + return x_in, ct_mean, ct_std + + return CT_normalize #, {'mean': ct_mean, 'std': ct_std} + + else: # modality == 'CHAOST2' : + + def MR_normalize(x_in): + return x_in, x_in.mean(), x_in.std() + + return MR_normalize #, {'mean': None, 'std': None} # we do not really need the global statistics for MR + diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/utils/image_transforms.py b/MRI_recon/code/Frequency-Diffusion/dataset/utils/image_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..277bdb02878221816c6a69720a7c98c41bbd2dcb --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/utils/image_transforms.py @@ -0,0 +1,319 @@ +""" +Image transforms functions for data augmentation +Credit to Dr. 
Jo Schlemper +""" +try: + from collections import Sequence +except: + from collections.abc import Sequence +import cv2 +import numpy as np +import scipy +from scipy.ndimage.filters import gaussian_filter +from scipy.ndimage.interpolation import map_coordinates +from numpy.lib.stride_tricks import as_strided + +###### UTILITIES ###### +def random_num_generator(config, random_state=np.random): + if config[0] == 'uniform': + ret = random_state.uniform(config[1], config[2], 1)[0] + elif config[0] == 'lognormal': + ret = random_state.lognormal(config[1], config[2], 1)[0] + else: + #print(config) + raise Exception('unsupported format') + return ret + +def get_translation_matrix(translation): + """ translation: [tx, ty] """ + tx, ty = translation + translation_matrix = np.array([[1, 0, tx], + [0, 1, ty], + [0, 0, 1]]) + return translation_matrix + + + +def get_rotation_matrix(rotation, input_shape, centred=True): + theta = np.pi / 180 * np.array(rotation) + if centred: + rotation_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), rotation, 1) + rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]]) + else: + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1]]) + return rotation_matrix + +def get_zoom_matrix(zoom, input_shape, centred=True): + zx, zy = zoom + if centred: + zoom_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), 0, zoom[0]) + zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]]) + else: + zoom_matrix = np.array([[zx, 0, 0], + [0, zy, 0], + [0, 0, 1]]) + return zoom_matrix + +def get_shear_matrix(shear_angle): + theta = (np.pi * shear_angle) / 180 + shear_matrix = np.array([[1, -np.sin(theta), 0], + [0, np.cos(theta), 0], + [0, 0, 1]]) + return shear_matrix + +###### AFFINE TRANSFORM ###### +class RandomAffine(object): + """Apply random affine transformation on a numpy.ndarray (H x W x C) + Comment by co1818: this is still doing affine on 2d (H x W plane). + A same transform is applied to all C channels + + Parameter: + ---------- + + alpha: Range [0, 4] seems good for small images + + order: interpolation method (c.f. opencv) + """ + + def __init__(self, + rotation_range=None, + translation_range=None, + shear_range=None, + zoom_range=None, + zoom_keep_aspect=False, + interp='bilinear', + order=3): + """ + Perform an affine transforms. + + Arguments + --------- + rotation_range : one integer or float + image will be rotated randomly between (-degrees, degrees) + + translation_range : (x_shift, y_shift) + shifts in pixels + + *NOT TESTED* shear_range : float + image will be sheared randomly between (-degrees, degrees) + + zoom_range : (zoom_min, zoom_max) + list/tuple with two floats between [0, infinity). + first float should be less than the second + lower and upper bounds on percent zoom. + Anything less than 1.0 will zoom in on the image, + anything greater than 1.0 will zoom out on the image. + e.g. 
(0.7, 1.0) will only zoom in, + (1.0, 1.4) will only zoom out, + (0.7, 1.4) will randomly zoom in or out + """ + + self.rotation_range = rotation_range + self.translation_range = translation_range + self.shear_range = shear_range + self.zoom_range = zoom_range + self.zoom_keep_aspect = zoom_keep_aspect + self.interp = interp + self.order = order + + def build_M(self, input_shape): + tfx = [] + final_tfx = np.eye(3) + if self.rotation_range: + rot = np.random.uniform(-self.rotation_range, self.rotation_range) + tfx.append(get_rotation_matrix(rot, input_shape)) + if self.translation_range: + tx = np.random.uniform(-self.translation_range[0], self.translation_range[0]) + ty = np.random.uniform(-self.translation_range[1], self.translation_range[1]) + tfx.append(get_translation_matrix((tx,ty))) + if self.shear_range: + rot = np.random.uniform(-self.shear_range, self.shear_range) + tfx.append(get_shear_matrix(rot)) + if self.zoom_range: + sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) + if self.zoom_keep_aspect: + sy = sx + else: + sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) + + tfx.append(get_zoom_matrix((sx, sy), input_shape)) + + for tfx_mat in tfx: + final_tfx = np.dot(tfx_mat, final_tfx) + + return final_tfx.astype(np.float32) + + def __call__(self, image): + # build matrix + input_shape = image.shape[:2] + M = self.build_M(input_shape) + + res = np.zeros_like(image) + #if isinstance(self.interp, Sequence): + if type(self.order) is list or type(self.order) is tuple: + for i, intp in enumerate(self.order): + res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp) + else: + # squeeze if needed + orig_shape = image.shape + image_s = np.squeeze(image) + res = affine_transform_via_M(image_s, M[:2], interp=self.order) + res = res.reshape(orig_shape) + + #res = affine_transform_via_M(image, M[:2], interp=self.order) + + return res + +def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST): + imshape = image.shape + shape_size = imshape[:2] + + # Random affine + warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1], + flags=interp, borderMode=borderMode) + + #print(imshape, warped.shape) + + warped = warped[..., np.newaxis].reshape(imshape) + + return warped + +###### ELASTIC TRANSFORM ###### +def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random): + """Elastic deformation of image as described in [Simard2003]_. + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + """ + assert image.ndim == 3 + shape = image.shape[:2] + + dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), + sigma, mode="constant", cval=0) * alpha + dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), + sigma, mode="constant", cval=0) * alpha + + x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') + indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))] + result = np.empty_like(image) + for i in range(image.shape[2]): + result[:, :, i] = map_coordinates( + image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape) + return result + + +def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False): + """Expects data to be (nx, ny, n1 ,..., nm) + params: + ------ + + alpha: + the scaling parameter. 
+ E.g.: alpha=2 => distorts images up to 2x scaling + + sigma: + standard deviation of gaussian filter. + E.g. + low (sig~=1e-3) => no smoothing, pixelated. + high (1/5 * imsize) => smooth, more like affine. + very high (1/2*im_size) => translation + """ + + if random_state is None: + random_state = np.random.RandomState(None) + + shape = image.shape + imsize = shape[:2] + dim = shape[2:] + + # Random affine + blur_size = int(4*sigma) | 1 + dx = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, + ksize=(blur_size, blur_size), sigmaX=sigma) * alpha + dy = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, + ksize=(blur_size, blur_size), sigmaX=sigma) * alpha + + # use as_strided to copy things over across n1...nn channels + dx = as_strided(dx.astype(np.float32), + strides=(0,) * len(dim) + (4*shape[1], 4), + shape=dim+(shape[0], shape[1])) + dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim)))) + + dy = as_strided(dy.astype(np.float32), + strides=(0,) * len(dim) + (4*shape[1], 4), + shape=dim+(shape[0], shape[1])) + dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim)))) + + coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim]) + indices = [np.reshape(e+de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:], + [dy, dx] + [0] * len(dim))] + + if lazy: + return indices + + return map_coordinates(image, indices, order=order, mode='reflect').reshape(shape) + +class ElasticTransform(object): + """Apply elastic transformation on a numpy.ndarray (H x W x C) + """ + + def __init__(self, alpha, sigma, order=1): + self.alpha = alpha + self.sigma = sigma + self.order = order + + def __call__(self, image): + if isinstance(self.alpha, Sequence): + alpha = random_num_generator(self.alpha) + else: + alpha = self.alpha + if isinstance(self.sigma, Sequence): + sigma = random_num_generator(self.sigma) + else: + sigma = self.sigma + return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order) + +class RandomFlip3D(object): + + def __init__(self, h=True, v=True, t=True, p=0.5): + """ + Randomly flip an image horizontally and/or vertically with + some probability. + + Arguments + --------- + h : boolean + whether to horizontally flip w/ probability p + + v : boolean + whether to vertically flip w/ probability p + + p : float between [0,1] + probability with which to apply allowed flipping operations + """ + self.horizontal = h + self.vertical = v + self.depth = t + self.p = p + + def __call__(self, x, y=None): + # horizontal flip with p = self.p + if self.horizontal: + if np.random.random() < self.p: + x = x[::-1, ...] + + # vertical flip with p = self.p + if self.vertical: + if np.random.random() < self.p: + x = x[:, ::-1, ...] + + if self.depth: + if np.random.random() < self.p: + x = x[..., ::-1] + + return x + + diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/utils/math.py b/MRI_recon/code/Frequency-Diffusion/dataset/utils/math.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0d76246b0c8fb757b6b589b7f889e0316696d0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/utils/math.py @@ -0,0 +1,272 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. 
+ + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + +# def fft2c(data): +# """ +# Apply centered 2 dimensional Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The FFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# data = torch.fft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + +# def ifft2c(data): +# """ +# Apply centered 2-dimensional Inverse Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The IFFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# # data = torch.ifft(data, 2, normalized=True) +# data = torch.fft.ifft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. 
+ """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. 
+ + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/utils/niftiio.py b/MRI_recon/code/Frequency-Diffusion/dataset/utils/niftiio.py new file mode 100644 index 0000000000000000000000000000000000000000..19fce7bc59793d6c2711b497ee01577433788172 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/utils/niftiio.py @@ -0,0 +1,47 @@ +""" +Utils for datasets +""" +import numpy as np +import numpy as np +import SimpleITK as sitk + + +def read_nii_bysitk(input_fid, peel_info = False): + """ read nii to numpy through simpleitk + peelinfo: taking direction, origin, spacing and metadata out + """ + img_obj = sitk.ReadImage(input_fid) + img_np = sitk.GetArrayFromImage(img_obj) + if peel_info: + info_obj = { + "spacing": img_obj.GetSpacing(), + "origin": img_obj.GetOrigin(), + "direction": img_obj.GetDirection(), + "array_size": img_np.shape + } + return img_np, info_obj + else: + return img_np + +def convert_to_sitk(input_mat, peeled_info): + """ + write a numpy array to sitk image object with essential meta-data + """ + nii_obj = sitk.GetImageFromArray(input_mat) + if peeled_info: + nii_obj.SetSpacing( peeled_info["spacing"] ) + nii_obj.SetOrigin( peeled_info["origin"] ) + nii_obj.SetDirection(peeled_info["direction"] ) + return nii_obj + +def np2itk(img, ref_obj): + """ + img: numpy array + ref_obj: reference sitk object for copying information from + """ + itk_obj = sitk.GetImageFromArray(img) + itk_obj.SetSpacing( ref_obj.GetSpacing() ) + itk_obj.SetOrigin( ref_obj.GetOrigin() ) + itk_obj.SetDirection( ref_obj.GetDirection() ) + return itk_obj + diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/utils/subsample.py b/MRI_recon/code/Frequency-Diffusion/dataset/utils/subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..d0620da3414c6077e4293376fb8a9be01ad19990 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/utils/subsample.py @@ -0,0 +1,195 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. 
+ """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. 
This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/utils/transform_albu.py b/MRI_recon/code/Frequency-Diffusion/dataset/utils/transform_albu.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac07fc3237abbf310cd0b088f5b49e1cd042735 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/utils/transform_albu.py @@ -0,0 +1,125 @@ +# -*- encoding: utf-8 -*- +#Time :2022/02/24 18:14:15 +#Author :Hao Chen +#FileName :trans_lib.py +#Version :2.0 + +import cv2 +import torch +import numpy as np +import albumentations as A +def gaussian_noise(img, mean, sigma): + return img + torch.FloatTensor(img.shape).normal_(mean=mean, std=sigma) +# from albumentations.pytorch import ShiftScaleRotate + +def GammaInterference(img): + # Shape Span + gamma = np.random.random() * 1.5 + 0.25 # 0.25 ~ 1.75 + # gamma = np.random.random() * 1.75 + 0.25 # 0.25 ~ 1.75 + img = gamma_concern(img, gamma) # concerntrate + + # Shape Tilt + choose = np.random.randint(0, 2) + direction = np.random.randint(0, 2) + + if choose == 0: + gamma = 0.2 + np.random.random() * 2.3 # 2.5 + img = gamma_power(img, gamma, direction) + else: + gamma = np.random.random() * 2.3 + 0.6 # 1.5 center + img = gamma_exp(img, gamma, direction) + + return img + + + +def get_resize_transforms(img_size = (192, 192)): + # if type == 'train': + return A.Compose([ + A.Resize(img_size[0], img_size[1]) + ], p=1.0, additional_targets={'image2': 'image', "mask2": "mask"}) + + +def 
get_albu_transforms(type="train", img_size = (192, 192)): + if type == 'train': + compose = [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.ShiftScaleRotate(shift_limit=0.2, scale_limit=(-0.2, 0.2), + rotate_limit=5, p=0.5), + + # A.Defocus(radius=(4, 8), alias_blur=(0.2, 0.4), p=0.5), + # A.GaussNoise(var_limit=(10.0, 25.0), p=0.5), + + # A.GaussianBlur(blur_limit=(3, 7), p=0.5), + # A.Emboss(alpha=(0.5, 1.0), strength=(0.5, 1.0), p=0.5), # Added + + # A.FDA([target_image], p=1, read_fn=lambda x: x) + # A.PixelDistributionAdaptation( reference_images=[reference_image], + + # A.Defocus(radius=(4, 8), alias_blur=(0.2, 0.4), p=0.5) + + # Randomly posterize between 2 and 5 bits + # A.Posterize(num_bits=(4, 6), p=0.5), + + # A.OneOf([ + # A.RandomShadow(p=1.0), + # A.Solarize(p=1.0), + # A.RandomSunFlare(p=1.0), + # ], p=0.5), + + # A.Saturation + # A.HueSaturationValue(hue_shift_limit=0, sat_shift_limit=0, val_shift_limit=5, p=0.5), + # A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), + # contrast_limit=(-0.1, 0.1), p=0.5), + # A.MaskDropout(p=0.5), + + A.OneOf([ + A.GridDistortion(num_steps=1, distort_limit=0.3, p=1.0), + A.ElasticTransform(alpha=3, sigma=15, alpha_affine=10, p=1.0) + ], p=0.5), + + A.Resize(img_size[0], img_size[1])] + else: + compose = [A.Resize(img_size[0], img_size[1])] + + return A.Compose(compose, p=1.0, additional_targets={'image2': 'image', "mask2": "mask"}) + + + + +# Beta function +def gamma_concern(img, gamma): + mean = torch.mean(img) + + img = (img - mean) * gamma + img = img + mean + img = torch.clip(img, 0, 1) + + return img + +def gamma_power(img, gamma, direction=0): + if direction == 1: + img = 1 - img + img = torch.pow(img, gamma) + + img = img / torch.max(img) + if direction == 1: + img = 1 - img + + return img + +def gamma_exp(img, gamma, direction=0): + if direction == 1: + img = 1 - img + + img = torch.exp(img * gamma) + img = img / torch.max(img) + + if direction == 1: + img = 1 - img + return img + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/utils/transform_utils.py b/MRI_recon/code/Frequency-Diffusion/dataset/utils/transform_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96f438903c6d476a8150ba1f8d1fe192e00cf5a5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/utils/transform_utils.py @@ -0,0 +1,245 @@ +""" +Utilities for image transforms, part of the code base credits to Dr. Jo Schlemper +""" +from os.path import join +import torch +import numpy as np +import torchvision.transforms as deftfx +from . 
import image_transforms as myit +import copy +import math +from torchvision.transforms.functional import rotate as torchrotate +from torchvision.transforms.functional import InterpolationMode + +my_augv = { +'flip' : { 'v':False, 'h':False, 't': False, 'p':0.25 }, +'affine' : { + 'rotate':20, + 'shift':(15,15), + 'shear': 20, + 'scale':(0.5, 1.5), +}, +'elastic' : {'alpha':0,'sigma':0}, # medium {'alpha':20,'sigma':5}, +'reduce_2d': True, +'gamma_range': (1.0, 1.0 ), #(0.2, 1.8), +'noise' : { + 'noise_std': 0, # 0.15 + 'clip_pm1': False + }, +'bright_contrast': { + 'contrast': (1.0, 1.0), #(0.60, 1.5), + 'bright': (0, 0)#(-10, 10) + } +} + +tr_aug = { + 'aug': my_augv +} + + +def get_contrast_example(image, random_angle=0, flip=0): + if flip == [3]: + flip = [1, 2] + + # [..., H, W] + image_rotate = torchrotate(image, random_angle, + interpolation=InterpolationMode.BILINEAR) # Bilinear + image_rotate = torch.flip(image_rotate, flip) + + return image_rotate + + + +def get_geometric_transformer(aug, order=3): + affine = aug['aug'].get('affine', 0) + alpha = aug['aug'].get('elastic',{'alpha': 0})['alpha'] + sigma = aug['aug'].get('elastic',{'sigma': 0})['sigma'] + flip = aug['aug'].get('flip', {'v': True, 'h': True, 't': True, 'p':0.125}) + + tfx = [] + if 'flip' in aug['aug']: + tfx.append(myit.RandomFlip3D(**flip)) + + if 'affine' in aug['aug']: + tfx.append(myit.RandomAffine(affine.get('rotate'), + affine.get('shift'), + affine.get('shear'), + affine.get('scale'), + affine.get('scale_iso',True), + order=order)) + + if 'elastic' in aug['aug']: + tfx.append(myit.ElasticTransform(alpha, sigma)) + + input_transform = deftfx.Compose(tfx) + return input_transform + +def get_intensity_transformer(aug): + + def gamma_tansform(img): + gamma_range = aug['aug']['gamma_range'] + if isinstance(gamma_range, tuple): + gamma = np.random.rand() * (gamma_range[1] - gamma_range[0]) + gamma_range[0] + cmin = img.min() + irange = (img.max() - cmin + 1e-5) + + img = img - cmin + 1e-5 + img = irange * np.power(img * 1.0 / irange, gamma) + img = img + cmin + + elif gamma_range == False: + pass + else: + raise ValueError("Cannot identify gamma transform range {}".format(gamma_range)) + return img + + def brightness_contrast(img): + ''' + Chaitanya,K. et al. Semi-Supervised and Task-Driven data augmentation,864in: International Conference on Information Processing in Medical Imaging,865Springer. pp. 29–41. 
+ ''' + cmin, cmax = aug['aug']['bright_contrast']['contrast'] + bmin, bmax = aug['aug']['bright_contrast']['bright'] + c = np.random.rand() * (cmax - cmin) + cmin + b = np.random.rand() * (bmax - bmin) + bmin + img_mean = img.mean() + img = (img - img_mean) * c + img_mean + b + return img + + def zm_gaussian_noise(img): + """ + zero-mean gaussian noise + """ + noise_sigma = aug['aug']['noise']['noise_std'] + noise_vol = np.random.randn(*img.shape) * noise_sigma + img = img + noise_vol + + if aug['aug']['noise']['clip_pm1']: # if clip to plus-minus 1 + img = np.clip(img, -1.0, 1.0) + return img + + def compile_transform(img): + # bright contrast + if 'bright_contrast' in aug['aug'].keys(): + img = brightness_contrast(img) + + # gamma + if 'gamma_range' in aug['aug'].keys(): + img = gamma_tansform(img) + + # additive noise + if 'noise' in aug['aug'].keys(): + img = zm_gaussian_noise(img) + + return img + + return compile_transform + + +def transform_with_label(aug, add_pseudolabel = False): + """ + Doing image geometric transform + Proposed image to have the following configurations + [H x W x C + CL] + Where CL is the number of channels for the label. It is NOT a one-hot thing + """ + + geometric_tfx = get_geometric_transformer(aug) + intensity_tfx = get_intensity_transformer(aug) + + def transform(comp, c_label, c_img, c_sam, nclass, is_train, use_onehot = False): + """ + Args + comp: a numpy array with shape [H x W x C + c_label] + c_label: number of channels for a compact label. Note that the current version only supports 1 slice (H x W x 1) + nc_onehot: -1 for not using one-hot representation of mask. otherwise, specify number of classes in the label + is_train: whether this is the training set or not. If not, do not perform the geometric transform + """ + comp = copy.deepcopy(comp) + if (use_onehot is True) and (c_label != 1): + raise NotImplementedError("Only allow compact label, also the label can only be 2d") + assert c_img + c_sam + c_label == comp.shape[-1], "only allow single slice 2D label" + + if is_train is True: + _label = comp[..., c_img ] + _sam = np.expand_dims(comp[..., c_img+c_label], axis=-1) + # compact to onehot + _h_label = np.float32(np.arange( nclass ) == (_label[..., None]) ) + # print("h_label=", _h_label.shape) + # print("_sam=", _sam.shape) + + comp = np.concatenate( [comp[..., :c_img ], _h_label, _sam], -1 ) + comp = geometric_tfx(comp) + # round one_hot labels to 0 or 1 + t_label_h = comp[..., c_img : -c_sam] + t_label_h = np.rint(t_label_h) + t_img = comp[..., 0 : c_img ] + t_sam = np.rint(comp[..., -c_sam:]) + + # intensity transform + t_img = intensity_tfx(t_img) + + if use_onehot is True: + t_label = t_label_h + else: + t_label = np.expand_dims(np.argmax(t_label_h, axis = -1), -1) + return t_img, t_label, t_sam + + return transform + + + + +def gamma_concern(img, gamma): + mean = np.mean(img) + + img = (img - mean) * gamma + img = img + mean + img = np.clip(img, 0, 1) + + return img + +def gamma_power(img, gamma, direction=0): + if direction == 1: + img = 1 - img + img = np.power(img, gamma) + + img = img / np.max(img) + if direction == 1: + img = 1 - img + + return img + +def gamma_exp(img, gamma, direction=0): + if direction == 1: + img = 1 - img + + img = np.exp(img * gamma) + img = img / np.max(img) + + if direction == 1: + img = 1 - img + return img + + +def GammaInterference(img): + # Shape Span + gamma = np.random.random() * 1.5 + 0.25 # 0.25 ~ 1.75 + # gamma = np.random.random() * 1.75 + 0.25 # 0.25 ~ 1.75 + img = gamma_concern(img, gamma) # 
concerntrate + + # Shape Tilt + choose = np.random.randint(0, 2) + direction = np.random.randint(0, 2) + + if choose == 0: + gamma = 0.2 + np.random.random() * 2.3 # 2.5 + img = gamma_power(img, gamma, direction) + else: + gamma = np.random.random() * 2.3 + 0.6 # 1.5 center + img = gamma_exp(img, gamma, direction) + + return img + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/dataset/utils/transforms.py b/MRI_recon/code/Frequency-Diffusion/dataset/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ec304cf13a66ee181491abe7d9adbd31b16e3f4b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/dataset/utils/transforms.py @@ -0,0 +1,487 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch + +from dataset.m4_utils.math import ifft2c, fft2c, complex_abs +from dataset.m4_utils.subsample import create_mask_for_mask_type, MaskFunc +import random + +from typing import Dict, Optional, Sequence, Tuple, Union +from matplotlib import pyplot as plt +import os + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data. + """ + data = data.numpy() + + return data[..., 0] + 1j * data[..., 1] + + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + +def mask_center(x, mask_from, mask_to): + mask = torch.zeros_like(x) + mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to] + + return mask + + +def center_crop(data, shape): + """ + Apply a center crop to the input real image or batch of real images. + + Args: + data (torch.Tensor): The input tensor to be center cropped. 
It should + have at least 2 dimensions and the cropping is applied along the + last two dimensions. + shape (int, int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image. + """ + assert 0 < shape[0] <= data.shape[-2] + assert 0 < shape[1] <= data.shape[-1] + + w_from = (data.shape[-2] - shape[0]) // 2 + h_from = (data.shape[-1] - shape[1]) // 2 + w_to = w_from + shape[0] + h_to = h_from + shape[1] + + return data[..., w_from:w_to, h_from:h_to] + + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + +def center_crop_to_smallest(x, y): + """ + Apply a center crop on the larger image to the size of the smaller. + + The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at + dim=-1 and y is smaller than x at dim=-2, then the returned dimension will + be a mixture of the two. + + Args: + x (torch.Tensor): The first image. + y (torch.Tensor): The second image + + Returns: + tuple: tuple of tensors x and y, each cropped to the minimim size. + """ + smallest_width = min(x.shape[-1], y.shape[-1]) + smallest_height = min(x.shape[-2], y.shape[-2]) + x = center_crop(x, (smallest_height, smallest_width)) + y = center_crop(y, (smallest_height, smallest_width)) + + return x, y + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +class DataTransform(object): + """ + Data Transformer for training U-Net models. + """ + + def __init__(self, which_challenge): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. 
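+            Note (clarifying comment, not in the upstream fastMRI docstring): this
+            transform's constructor only takes ``which_challenge``; no ``mask_func``
+            or ``use_seed`` is accepted or applied here. Its ``__call__`` instead
+            builds the low-resolution input by center-cropping k-space to 160 x 160
+            before the inverse FFT.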
+ """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.which_challenge = which_challenge + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. + """ + kspace = to_tensor(kspace) + + # inverse Fourier transform to get zero filled solution + image = ifft2c(kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + + # getLR + imgfft = fft2c(image) + imgfft = complex_center_crop(imgfft, (160, 160)) + LR_image = ifft2c(imgfft) + + # absolute value + LR_image = complex_abs(LR_image) + + # normalize input + LR_image, mean, std = normalize_instance(LR_image, eps=1e-11) + LR_image = LR_image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return LR_image, target, mean, std, fname, slice_num + +class DenoiseDataTransform(object): + def __init__(self, size, noise_rate): + super(DenoiseDataTransform, self).__init__() + self.size = (size, size) + self.noise_rate = noise_rate + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + max_value = attrs["max"] + + #target + target = to_tensor(target) + target = center_crop(target, self.size) + target, mean, std = normalize_instance(target, eps=1e-11) + target = target.clamp(-6, 6) + + #image + kspace = to_tensor(kspace) + complex_image = ifft2c(kspace) #complex_image + image = complex_center_crop(complex_image, self.size) + noise_image = self.rician_noise(image, max_value) + noise_image = complex_abs(noise_image) + + noise_image = normalize(noise_image, mean, std, eps=1e-11) + noise_image = noise_image.clamp(-6, 6) + + return noise_image, target, mean, std, fname, slice_num + + + def rician_noise(self, X, noise_std): + #Add rician noise with variance sampled uniformly from the range 0 and 0.1 + noise_std = random.uniform(0, noise_std*self.noise_rate) + Ir = X + noise_std * torch.randn(X.shape) + Ii = noise_std*torch.randn(X.shape) + In = torch.sqrt(Ir ** 2 + Ii ** 2) + return In + + +def apply_mask( + data: torch.Tensor, + mask_func: MaskFunc, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Optional[Sequence[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Subsample given k-space by multiplying with a mask. + Args: + data: The input k-space data. 
This should have at least 3 dimensions, + where dimensions -3 and -2 are the spatial dimensions, and the + final dimension has size 2 (for complex values). + mask_func: A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed: Seed for the random number generator. + padding: Padding value to apply for mask. + Returns: + tuple containing: + masked data: Subsampled k-space data + mask: The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + + +class ReconstructionTransform(object): + """ + Data Transformer for training U-Net models. + """ + + def __init__(self, which_challenge, mask_func=None, use_seed=True): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.mask_func = mask_func + self.which_challenge = which_challenge + self.use_seed = use_seed + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. 
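+                Example (illustrative sketch; the filename below is hypothetical):
+                    transform = ReconstructionTransform("singlecoil")  # no mask_func, k-space used as-is
+                    image, target, mean, std, fname, slice_num = transform(
+                        kspace, None, target, attrs, "file_brain_0001.h5", 0)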
+ """ + kspace = to_tensor(kspace) + + # apply mask + if self.mask_func: + seed = None if not self.use_seed else tuple(map(ord, fname)) + masked_kspace, mask = apply_mask(kspace, self.mask_func, seed) + else: + masked_kspace = kspace + + # inverse Fourier transform to get zero filled solution + image = ifft2c(masked_kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + # print('image',image.shape) + # absolute value + image = complex_abs(image) + + # apply Root-Sum-of-Squares if multicoil data + if self.which_challenge == "multicoil": + image = rss(image) + + # normalize input + image, mean, std = normalize_instance(image, eps=1e-11) + image = image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return image, target, mean, std, fname, slice_num + + +def build_transforms(MASKTYPE, CENTER_FRACTIONS, ACCELERATIONS, mode = 'train'): + + challenge = 'singlecoil' + return ReconstructionTransform(challenge) + + # if mode == 'train': + # mask = create_mask_for_mask_type( + # MASKTYPE, CENTER_FRACTIONS, ACCELERATIONS, + # ) + # return ReconstructionTransform(challenge, mask, use_seed=False) + # + # elif mode == 'val' or mode == 'test': + # mask = create_mask_for_mask_type( + # MASKTYPE, CENTER_FRACTIONS, ACCELERATIONS, + # ) + # return ReconstructionTransform(challenge, mask) + # + # else: + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/__init__.py b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..946aef9fd093eb73d770679740b159912e0dad4d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/__init__.py @@ -0,0 +1,2 @@ +from diffusion_pytorch.diffusion_gaussian import GaussianDiffusion, Trainer +# from diffusion_pytorch.new_twobranch_model import Model as TwoBranchNewModel diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/__init__.py b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/brats_mask.py b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/brats_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..2a70094da5195c6a78ed118bc5e8b352adf1b5b6 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/brats_mask.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/04/08 +对BRATS 2020数据集进行Pre-processing, 得到各个模态的under-sampled input image和2d groung-truth. +""" +import os +import argparse +import numpy as np +import nibabel as nib +from scipy import ndimage as nd +from scipy import ndimage +from skimage import filters +from skimage import io +import torch +import torch.fft +from matplotlib import pyplot as plt + +MRIDOWN=8 +SNR = 0 + + +class MaskFunc_Cartesian: + """ + MaskFunc creates a sub-sampling mask of a given shape. 
+ The mask selects a subset of columns from the input k-space data. If the k-space data has N + columns, the mask picks out: + a) N_low_freqs = (N * center_fraction) columns in the center corresponding to + low-frequencies + b) The other columns are selected uniformly at random with a probability equal to: + prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). + This ensures that the expected number of columns selected is equal to (N / acceleration) + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is + called. + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there + is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% + probability that 8-fold acceleration with 4% center fraction is selected. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly + each time. + accelerations (List[int]): Amount of under-sampling. This should have the same length + as center_fractions. If multiple values are provided, then one of these is chosen + uniformly each time. An acceleration of 4 retains 25% of the columns, but they may + not be spaced evenly. + """ + if len(center_fractions) != len(accelerations): + raise ValueError('Number of center fractions should match number of accelerations') + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random.RandomState() + + def __call__(self, shape, seed=None): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The shape should have + at least 3 dimensions. Samples are drawn along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting the seed + ensures the same mask is generated each time for the same shape. + Returns: + torch.Tensor: A mask of the specified shape. 
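+            Example (illustrative numbers only): for num_cols = 240,
+            center_fraction = 0.08 and acceleration = 4, int(round(240 * 0.08)) = 19
+            central columns are always kept, and the remaining columns are sampled
+            with probability (240 / 4 - 19) / (240 - 19) ~= 0.19, so roughly
+            240 / 4 = 60 columns are retained in expectation.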
+ """ + if len(shape) < 3: + raise ValueError('Shape should have 3 or more dimensions') + + self.rng.seed(seed) + num_cols = shape[-2] + + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + # Create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs + 1e-10) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad:pad + num_low_freqs] = True + + # Reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + mask = mask.repeat(shape[0], 1, 1) + + return mask + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2)) + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + spectrum = spectrum * mask[None, :, :, None] + return spectrum + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2)) + + return image + + +def get_undersample(): + ff = MaskFunc_Cartesian([0.2], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4 + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] + + + plt.imshow(mask) + plt.show() + + +def simulate_undersample_mri(raw_mri): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + ff = MaskFunc_Cartesian([0.2], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4 + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + kspace = mri_fourier_transform_2d(mri, mask) + kspace = add_gaussian_noise(kspace) + mri_recon = mri_inver_fourier_transform_2d(kspace) + kdata = torch.sqrt(kspace.real ** 2 + kspace.imag ** 2 + 1e-10) + kdata = kdata.data.numpy()[0, :, :, 0] + + under_img = torch.sqrt(mri_recon.real ** 2 + mri_recon.imag ** 2) + under_img = under_img.data.numpy()[0, :, :, 0] + + return under_img, kspace + + +def add_gaussian_noise(img, snr=15): + ### 根据SNR确定noise的放大比例 + num_pixels = img.shape[0]*img.shape[1]*img.shape[2]*img.shape[3] + psr = torch.sum(torch.abs(img.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(img.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(img.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(img.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noise_img = img + noise + # print("original image:", img) + # print("gaussian noise:", noise) + + return noise_img + + +def complexsing_addnoise(img, snr): + ### add noise to the real part of the image. + img_numpy = img.cpu().numpy() + # print("kspace data:", img) + s_r = np.real(img_numpy) + num_pixels = s_r.shape[0]*s_r.shape[1]*s_r.shape[2]*s_r.shape[3] + psr = np.sum(np.abs(s_r)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + # print("PSR:", psr, "PNR:", pnr) + noise_r = np.random.randn(num_pixels)*np.sqrt(pnr) + + ### add noise to the iamginary part of the image. 
+ s_im = np.imag(img_numpy) + psim = np.sum(np.abs(s_im)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = np.random.randn(num_pixels)*np.sqrt(pnim) + + noise = torch.Tensor(noise_r) + 1j*torch.Tensor(noise_im) + sn = img + noise + # print("noisy data:", sn) + # sn = torch.Tensor(sn) + + return sn + + + +def _parse(rootdir): + filetree = {} + + for sample_file in os.listdir(rootdir): + sample_dir = rootdir + sample_file + subject = sample_file + + for filename in os.listdir(sample_dir): + modality = filename.split('.').pop(0).split('_')[-1] + + if subject not in filetree: + filetree[subject] = {} + filetree[subject][modality] = filename + + return filetree + + + +def clean(rootdir, savedir, source_modality, target_modality): + filetree = _parse(rootdir) + print("filetree:", filetree) + + if not os.path.exists(savedir+'/img_norm'): + os.makedirs(savedir+'/img_norm') + + for subject, modalities in filetree.items(): + print(f'{subject}:') + + if source_modality not in modalities or target_modality not in modalities: + print('-> incomplete') + continue + + source_path = os.path.join(rootdir, subject, modalities[source_modality]) + target_path = os.path.join(rootdir, subject, modalities[target_modality]) + + source_image = nib.load(source_path) + target_image = nib.load(target_path) + + source_volume = source_image.get_fdata() + target_volume = target_image.get_fdata() + source_binary_volume = np.zeros_like(source_volume) + target_binary_volume = np.zeros_like(target_volume) + + print("source volume:", source_volume.shape) + print("target volume:", target_volume.shape) + + for i in range(source_binary_volume.shape[-1]): + source_slice = source_volume[:, :, i] + target_slice = target_volume[:, :, i] + + if source_slice.min() == source_slice.max(): + print("invalide source slice") + source_binary_volume[:, :, i] = np.zeros_like(source_slice) + else: + source_binary_volume[:, :, i] = ndimage.morphology.binary_fill_holes( + source_slice > filters.threshold_li(source_slice)) + + if target_slice.min() == target_slice.max(): + print("invalide target slice") + target_binary_volume[:, :, i] = np.zeros_like(target_slice) + else: + target_binary_volume[:, :, i] = ndimage.morphology.binary_fill_holes( + target_slice > filters.threshold_li(target_slice)) + + source_volume = np.where(source_binary_volume, source_volume, np.ones_like( + source_volume) * source_volume.min()) + target_volume = np.where(target_binary_volume, target_volume, np.ones_like( + target_volume) * target_volume.min()) + ## resize + if source_image.header.get_zooms()[0] < 0.6: + scale = np.asarray([240, 240, source_volume.shape[-1]]) / np.asarray(source_volume.shape) + source_volume = nd.zoom(source_volume, zoom=scale, order=3, prefilter=False) + target_volume = nd.zoom(target_volume, zoom=scale, order=0, prefilter=False) + + # save volume into images + source_volume = (source_volume-source_volume.min())/(source_volume.max()-source_volume.min()) + target_volume = (target_volume-target_volume.min())/(target_volume.max()-target_volume.min()) + + for i in range(source_binary_volume.shape[-1]): + source_binary_slice = source_binary_volume[:, :, i] + target_binary_slice = target_binary_volume[:, :, i] + if source_binary_slice.max() > 0 and target_binary_slice.max() > 0: + dd = target_volume.shape[0] // 2 + target_slice = target_volume[dd - 120:dd + 120, dd - 120:dd + 120, i] + source_slice = source_volume[dd - 120:dd + 120, dd - 120:dd + 120, i] + print("source slice range:", source_slice.shape) + print("target slice range:", 
target_slice.max(), target_slice.min()) + # undersample MRI + source_under_img, source_kspace = simulate_undersample_mri(source_slice) + target_under_img, target_kspace = simulate_undersample_mri(target_slice) + + # # io.imsave(savedir+'/img_norm/'+subject+'_'+str(i)+'_'+source_modality+'.png', (source_slice * 255.0).astype(np.uint8)) + io.imsave(savedir + '/img_temp/' + subject + '_' + str(i) + '_' + source_modality + '_' + str(MRIDOWN) + 'X_' + str(SNR) + 'dB_undermri.png', + (source_under_img * 255.0).astype(np.uint8)) + + # io.imsave(savedir + '/img_temp/' + subject + '_' + str(i) + '_' + source_modality + '_' + str(MRIDOWN) + 'X_undermri.png', + # (source_under_img * 255.0).astype(np.uint8)) + # # io.imsave(savedir+'/img_norm/'+subject+'_'+str(i)+'_'+target_modality+'.png', (target_slice * 255.0).astype(np.uint8)) + # io.imsave(savedir + '/img_norm/' + subject + '_' + str(i) + '_' + target_modality + '_' + str(MRIDOWN) + 'X_undermri.png', + # (target_under_img * 255.0).astype(np.uint8)) + + # np.savez_compressed(rootdir + '/img_norm/' + subject + '_' + str(i) + '_' + target_modality + '_raw_'+str(MRIDOWN)+'X'+str(CTNVIEW)+'P', + # kspace=kspace, under_t1=under_img, + # t1=source_slice, ct=target_slice) + + +def main(args): + clean(args.rootdir,args.savedir, args.source, args.target) + + +if __name__ == '__main__': + get_undersample() diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/example_mask/brats_4X_mask.npy b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/example_mask/brats_4X_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..bdf32304f95640286541ceb1068582dc69b0d60a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/example_mask/brats_4X_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76341ba680a0bc9c80389e01f8511e5bd99ab361eeb48d83516904b84cccc518 +size 460928 diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/example_mask/brats_8X_mask.npy b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/example_mask/brats_8X_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..c389e708adeb3307db90ff071599256b8f59dab5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/example_mask/brats_8X_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c5160add079e8f4dc2496e5ef87c110015026d9f6116329da2238a73d8bc104 +size 230528 diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/extract_example_mask.py b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/extract_example_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..630c51d74f70c00fba605fd06761eb9a73c9d3e9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/extract_example_mask.py @@ -0,0 +1,80 @@ +import matplotlib.pyplot as plt +import torch +import numpy as np +from torch.fft import fft2, ifft2, fftshift, ifftshift + +# brats 4X +example = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/BraTS20_Training_036_90_t2_4X_undermri.png" +gt = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/BraTS20_Training_036_90_t2.png" +save_file = "./example_mask/brats_4X_mask.npy" + + +# example = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/BraTS20_Training_036_90_t2_8X_undermri.png" +# gt = 
"/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/BraTS20_Training_036_90_t2.png" +# save_file = "./example_mask/brats_8X_mask.npy" + + +example_img = plt.imread(example) # cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) +gt = plt.imread(gt) # cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) + +print("example_img shape: ", example_img.shape) +plt.imshow(example_img, cmap='gray') +plt.title("Example Frequency Image") +plt.show() + +example_img = torch.from_numpy(example_img).float() +fre = fftshift(fft2(example_img)) # ) + +amp = torch.log(torch.abs(fre)) +angle = torch.angle(fre) + +plt.imshow(amp.squeeze(0).squeeze(0).numpy()) +plt.show() + +plt.imshow(angle.squeeze(0).squeeze(0).numpy()) +plt.show() + + +gt_fre = fftshift(fft2(torch.from_numpy(gt).float())) # ) +gt_amp = torch.log(torch.abs(gt_fre)) +plt.imshow(gt_amp.squeeze(0).squeeze(0).numpy()) +plt.show() +gt_angle = torch.angle(gt_fre) +plt.imshow(gt_angle.squeeze(0).squeeze(0).numpy()) +plt.show() + + +amp_mask = gt_amp.squeeze(0).squeeze(0).numpy() - amp.squeeze(0).squeeze(0).numpy() +amp_mask = np.mean(amp_mask, axis=0, keepdims=True) + +thres = np.mean(amp_mask) * 0.73 + + +new_mask = (amp_mask < thres) * 1.0 +new_mask = np.repeat(new_mask, 240, axis=0) + +amp_mask[amp_mask < thres] = 1 +amp_mask[amp_mask >= thres] = 0 + + +#duplicate +amp_mask = np.repeat(amp_mask, 240, axis=0) + +plt.imshow(gt_amp.squeeze(0).squeeze(0).numpy() - amp.squeeze(0).squeeze(0).numpy()) +plt.show() +plt.imshow(gt_angle.squeeze(0).squeeze(0).numpy() - angle.squeeze(0).squeeze(0).numpy()) +plt.show() +plt.imshow(new_mask) +plt.show() + +np.save(save_file, new_mask) + + + +load_backmask = np.load(save_file) +plt.imshow(load_backmask) +plt.show() + +size = load_backmask.shape[0] * load_backmask.shape[1] +print("shape=", size, load_backmask.sum()/size) + diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/k_degradation.py b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/k_degradation.py new file mode 100644 index 0000000000000000000000000000000000000000..6c74f69a9437565dede0875af404e234819445af --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/k_degradation.py @@ -0,0 +1,512 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift +try: + from diffusion_pytorch.degradation.mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquiSpacedMaskFunc, RandomPatchFunc +except: + from mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquiSpacedMaskFunc, RandomPatchFunc + +from torch import nn +import matplotlib.pyplot as plt + +def get_fade_kernel(dims, std): + fade_kernel = tgm.image.get_gaussian_kernel2d(dims, std) + fade_kernel = fade_kernel / torch.max(fade_kernel) + fade_kernel = torch.ones_like(fade_kernel) - fade_kernel + # if device_of_kernel == 'cuda': + # fade_kernel = fade_kernel.cuda() + fade_kernel = fade_kernel[1:, 1:] + return fade_kernel + + + +def get_fade_kernels(fade_routine, num_timesteps, image_size, kernel_std,initial_mask): + kernels = [] + for i in range(num_timesteps): + if fade_routine == 'Incremental': + kernels.append(get_fade_kernel((image_size + 1, image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + elif fade_routine == 'Constant': + kernels.append(get_fade_kernel( + (image_size + 1, image_size + 1), + 
(kernel_std, kernel_std))) + + elif fade_routine == 'Random_Incremental': + kernels.append(get_fade_kernel((2 * image_size + 1, 2 * image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + return torch.stack(kernels) + + +# --------------------------- +# Kspace kernels +# --------------------------- +# cartesian_regular +def get_mask_func(mask_method, af, cf): + if mask_method == 'cartesian_regular': + return EquispacedMaskFractionFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == 'cartesian_random': + return RandomMaskFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == "random": + return RandomMaskFunc([cf], [af]) + + elif mask_method == "randompatch": + return RandomPatchFunc([cf], [af]) + + elif mask_method == "equispaced": + return EquiSpacedMaskFunc([cf], [af]) + + else: + raise NotImplementedError + + +use_fix_center_ratio = False + +class Noisy_Patch(nn.Module): + def __init__(self): + super(Noisy_Patch, self).__init__() + self.af_list = [] + self.cf_list = [] + self.fe_list = [] + self.pe_list = [] + self.seed = 0 + + def append_list(self, at, cf, fe, pe): + self.af_list.append(at) + self.cf_list.append(cf) + self.fe_list.append(fe) + self.pe_list.append(pe) + + def get_noisy_patches(self, t): + af = self.af_list[t] + cf = self.cf_list[t] + fe = self.fe_list[t] + pe = self.pe_list[t] + + patch_mask = get_mask_func("randompatch", af, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=self.seed) # mask (numpy): (fe, pe) + return mask_ + + def forward(self, mask, ts): + # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + # print("use_patch_kernel forward:", t) + # print("mask = ", mask.shape) + # masks_ = [] + for id, t in enumerate(ts): + mask_ = self.get_noisy_patches(t)[0] + # print("mask_ = ", mask_.shape) + # print("mask[id, t] =", mask[t].shape) + + mask[t] = mask_.to(mask[t].device) * mask[t] + self.seed += ts[0].item() + + # masks_ = torch.stack(masks_).cuda() + # print("masks_ = ", masks_.shape) + # print("mask = ", mask.shape) # B, T, H, W + + return mask + +get_noisy_patches = Noisy_Patch() + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False): + mask_func = get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random', 'equispaced']: + + mask, num_low_freq = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'equispaced': + mask, num_low_freq = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask 
(torch): (pe, pe) --> (1, pe, pe) + + else: + raise NotImplementedError + + # print("return mask = ", mask.shape) + return mask + + + +def get_ksu_kernel(timesteps, image_size, + ksu_routine="LogSamplingRate", + mask_method="cartesian_random", + accelerated_factor=4, is_training = False, example_frequency_img=None): + + + if accelerated_factor == 4: + if is_training: + mask_method, center_fraction = "cartesian_random", 0.08 # 0.15 + else: + mask_method, center_fraction = "cartesian_random", 0.08 # 0.08 + + elif accelerated_factor == 8: + if is_training: + mask_method, center_fraction = "equispaced", 0.04 # 0.04 + else: + mask_method, center_fraction = "equispaced", 0.04 + + + center_ratio_factor = center_fraction * accelerated_factor + + masks = [] + noisy_masks = [] + ksu_mask_pe = ksu_mask_fe = image_size # , ksu_mask_pe=320, ksu_mask_fe=320 + # ksu_mask_fe + if ksu_routine == 'LinearSamplingRate': + # Generate the sampling rate list with torch.linspace, reversed, and skip the first element + sr_list = torch.linspace(start=1/accelerated_factor, end=1, steps=timesteps + 1).flip(0) + # Start from 0.01 + for sr in sr_list: + af = 1 / sr # * accelerated_factor # acceleration factor + cf = center_fraction if use_fix_center_ratio else sr_list[0] * center_ratio_factor + + masks.append(get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe, is_training=is_training)) + + elif ksu_routine == 'LogSamplingRate': + # MRI-Specific Masking: + # Design the frequency masking schedule to prioritize central k-space frequencies early in the process. + # Central k-space contains low-frequency information critical for image contrast. + + + # Generate the sampling rate list with torch.logspace, reversed, and skip the first element + sr_list = torch.logspace(start=-torch.log10(torch.tensor(accelerated_factor)), + end=0, steps=timesteps + 1).flip(0) + + af = 1 / sr_list[-1] + cf = center_fraction if use_fix_center_ratio else sr_list[-1] * center_ratio_factor + # print("af = ", af, cf) + + # Full + if isinstance(example_frequency_img, str): + # read in image and get frequency space: + example_img = plt.imread(example_frequency_img) #cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) + print("example_img shape: ", example_img.shape) + plt.imshow(example_img, cmap='gray') + plt.title("Example Frequency Image") + plt.show() + + example_img = torch.from_numpy(example_img).float() + fre = fftshift(fft2(example_img) ) # ) + amp = torch.log(torch.abs(fre)) + plt.imshow(amp.squeeze(0).squeeze(0).numpy()) + plt.show() + angle= torch.angle(fre) + plt.imshow(angle.squeeze(0).squeeze(0).numpy()) + plt.show() + + + cache_mask = get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe) + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + masks.append(cache_mask) + + sr_list = sr_list[:-1].flip(0) # Flip? 
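+        # Each remaining timestep lowers the effective acceleration factor: starting
+        # from the cached (most undersampled) mask, the not-yet-sampled k-space lines
+        # closest to the centre are switched on until int(H / af) lines are active,
+        # so the masks are nested across timesteps (later masks contain earlier ones).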
+ + for sr in sr_list: + af = 1 / sr + cf = center_fraction if use_fix_center_ratio else sr * center_ratio_factor + # print("af = ", af, cf) + + H, W = cache_mask.shape[1], cache_mask.shape[2] + new_mask = cache_mask.clone() + + # Add additional lines to the mask based on new acceleration factor + total_lines = H + sampled_lines = int(total_lines / af) + existing_lines = new_mask.squeeze(0).sum(dim=0).nonzero(as_tuple=True)[0].tolist() + + remaining_lines = [i for i in range(total_lines) if i not in existing_lines] + + if sampled_lines > len(existing_lines): + center = W // 2 + additional_lines = sampled_lines - len(existing_lines) # sample number + + sorted_indices = sorted(remaining_lines, key=lambda x: abs(x - center)) + + # Take the closest `additional_lines` indices + sampled_indices = sorted_indices[:additional_lines] + + # Remove sampled indices from remaining_lines + for idx in sampled_indices: + remaining_lines.remove(idx) + + # Update new_mask for each sampled index + for idx in sampled_indices: + new_mask[:, :, idx] = 1.0 + + # if sampled_lines > len(existing_lines): + # additional_lines = sampled_lines - len(existing_lines) # sample number + # + # # Random line + # # sampled_indices = np.random.choice(remaining_lines, additional_lines, replace=False) + # + # # Close to the center + # center = W // 2 # Calculate the center index + # # Find the indices of zeros closest to the center + # sampled_indices = sorted(remaining_lines, key=lambda x: abs(x - center))[0] + # remaining_lines.remove(sampled_indices) + # + # # sampled_indices = + # new_mask[:, :, sampled_indices] = 1.0 + + + + cache_mask = new_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + + masks.append(cache_mask) + + # reverse + masks = masks[::-1] + noisy_masks = masks # noisy_masks[::-1] + + + elif mask_method == 'gaussian_2d': + raise NotImplementedError("Gaussian 2D mask type is not implemented.") + + else: + raise NotImplementedError(f'Unknown k-space undersampling routine {ksu_routine}') + + # Return masks, excluding the first one + return masks#, noisy_masks[1:] + + + +class high_fre_mask: + def __init__(self): + self.mask_cache = {} + + def __call__(self, H, W): + if (H, W) in self.mask_cache: + return self.mask_cache[(H, W)] + center_x, center_y = H // 2, W // 2 + radius = H//8 # 影响的频率范围半径 + + high_freq_mask = torch.ones(H, W) + for i in range(H): + for j in range(W): + if (i - center_x) ** 2 + (j - center_y) ** 2 <= radius ** 2: + high_freq_mask[i, j] = 0.0 + self.mask_cache[(H, W)] = high_freq_mask + return high_freq_mask + + +high_fre_mask_cls = high_fre_mask() + + + +def apply_ksu_kernel(x_start, mask, params_dict=None, pixel_range='mean_std', + use_fre_noise=False): + fft, mask = apply_tofre(x_start, mask) + + + # Use the high frequency mask to add noise + if use_fre_noise: + fft = fft * mask + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + _, _, H, W = fft.shape + + high_freq_mask = high_fre_mask_cls(H, W).to(fft.device) + high_freq_mask = high_freq_mask.unsqueeze(0).unsqueeze(0).repeat(fft.shape[0], 1, 1, 1) + + # Background Noise + sigma = 0.1 * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + # noise_magnitude = sigma * fft_magnitude # fft_magnitude.mean() + mean_mag = fft_magnitude.sum() / (mask.sum() + 1) + # print("mean_mag = ", mean_mag) + + noise_magnitude_high = noise * (mean_mag) * (1 - mask) # high_freq_mask + + # Add noise to unmasked frequencies to maintain the 
stochasticity that diffusion models typically rely on. + # This prevents overfitting to specific frequency ranges. + sigma = 5/255 * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude_low = noise * fft_magnitude * mask # (1 - high_freq_mask) + + # fft_noisy_magnitude = fft_magnitude * mask + noise_magnitude * high_freq_mask * (1 - mask) + fft_noisy_magnitude = fft_magnitude * mask + fft_noisy_magnitude += noise_magnitude_low # + noise_magnitude_high + fft_noisy_magnitude = torch.clamp(fft_noisy_magnitude, min=0.0) + + fft = fft_noisy_magnitude * torch.exp(1j * fft_phase) + + else: + fft = fft * mask + + + x_ksu = apply_to_spatial(fft) + + return x_ksu + + + +def apply_tofre(x_start, mask): + kspace = fftshift(fft2(x_start, norm=None, dim=(-2, -1)), dim=(-2, -1)) # Default: all dimensions + mask = mask.to(kspace.device) + return kspace, mask + +def apply_to_spatial(fft): + + x_ksu = ifft2(ifftshift(fft, dim=(-2, -1)), norm=None, dim=(-2, -1)) # ortho + x_ksu = x_ksu.real #torch.abs(x_ksu) # + + return x_ksu + + +if __name__ == "__main__": + # First STEP + import matplotlib.pyplot as plt + import numpy as np + + image_size = 64 + + masks = get_ksu_kernel(25, image_size, + "LinearSamplingRate", is_training=True) # LogSamplingRate + + + batch_size = 1 + + img = plt.imread("/Users/haochen/Documents/GitHub/Frequency-Diffusion/draw/assets/BraTS20_Training_001_86_t1.png") + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + # to gray scale + # img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + # img = np.transpose(img, (2, 0, 1)) + # img = img[0] + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + print(" img shape: ", img.shape, img.max(), img.min()) + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + print("rand_x shape:", rand_x.shape, rand_x) + + img = img * 2 - 1 # + + masked_img = [] + + # masks = np.asarray(masks) + for m in masks: + print("m shape: ", m.shape) + m = m.unsqueeze(0) + img = apply_ksu_kernel(img, m, pixel_range='-1_1', ) + masked_img.append(img) + + masks = np.concatenate(masks, axis=-1)[0] + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + print(" masked_img shape: ", masked_img.shape) + print(" mask shape: ", masks.shape) + + img = np.concatenate([masks, masked_img], axis=0) + + plt.figure(figsize=(100, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + + print("\n\nSecond stage...") + + + # Second STEP + import matplotlib.pyplot as plt + import numpy as np + + image_size = 64 + batch_size = 1 + t = 25 + kspace_kernels = get_ksu_kernel(t, image_size, ksu_routine="LogSamplingRate", is_training=True) # 2 * + kspace_kernels = torch.stack(kspace_kernels).squeeze(1) + + img = plt.imread( + "/Users/haochen/Documents/GitHub/Frequency-Diffusion/draw/assets/BraTS20_Training_001_86_t1.png") + img = cv2.resize(img, (image_size, image_size)) + + # img = np.transpose(img, (2, 0, 1)) + # img = img[0] + + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + print(" img shape: ", img.shape, img.max(), img.min()) + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + print("rand_x shape:", rand_x.shape, rand_x) + + for i in range(batch_size): + print("kspace_kernels[j] shape = ", kspace_kernels[i].shape, 
rand_x[i]) + # k = kspace_kernels[j].clone() + rand_kernels.append(torch.stack( + [kspace_kernels[j][: image_size, # rand_x[i]:rand_x[i] + + : image_size] for j in + range(len(kspace_kernels))])) + + rand_kernels = torch.stack(rand_kernels) + + # rand_kernels shape: torch.Size([24, 5, 128, 128]) + print("=== rand_kernels: ", rand_kernels.shape, kspace_kernels[0].shape) + + masked_img = [] + masks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + masks.append(k) + # print("-- k shape: ", k.shape) + # print("-- img shape: ", img.shape) + + img = apply_ksu_kernel(img, k, pixel_range='0_1') + masked_img.append(img) + + masks = np.concatenate(masks, axis=-1)[0] + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + # print(" masked_img shape: ", masked_img.shape) + # print(" mask shape: ", masks.shape) + + img = np.concatenate([masks, masked_img], axis=0) + + plt.figure(figsize=(100, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/kspace_test.py b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/kspace_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1b00fc4c3af61773497301d2bc5344642c1b4a9a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/kspace_test.py @@ -0,0 +1,272 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift +import matplotlib.pyplot as plt + +try: + from diffusion_pytorch.degradation.mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquiSpacedMaskFunc, \ + RandomPatchFunc +except: + from mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquiSpacedMaskFunc, RandomPatchFunc + +try: + from .k_degradation import get_fade_kernel, get_fade_kernels, get_mask_func, high_fre_mask, get_ksu_kernel +except: + from k_degradation import get_fade_kernel, get_fade_kernels, get_mask_func, high_fre_mask, get_ksu_kernel + + +use_fix_center_ratio = False + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False): + mask_func = get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random']: + + mask, num_low_freq = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + if is_training: # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + af_new = 1.0 + (af - 1.0) / 2 + # af_new = max(af_new, 1.0) + + patch_mask = get_mask_func("randompatch", af_new, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=seed) # mask (numpy): (fe, pe) + + mask = mask_ * mask + + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (pe, pe) --> 
(1, pe, pe) + + else: + raise NotImplementedError + + # print("return mask = ", mask.shape) + return mask + + +# ksu_masks = get_ksu_kernels() +# (C, H, W) --> (B, C, H, W) + + +high_fre_mask_cls = high_fre_mask() + + +def apply_ksu_kernel(x_start, mask, params_dict=None, pixel_range='mean_std', + use_fre_noise=False, return_mask=False): + fft, mask = apply_tofre(x_start, mask, params_dict, pixel_range) + + # Use the high frequency mask to add noise + if use_fre_noise: + fft = fft * mask + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + _, _, H, W = fft.shape + + high_freq_mask = high_fre_mask_cls(H, W).to(fft.device) + high_freq_mask = high_freq_mask.unsqueeze(0).unsqueeze(0).repeat(fft.shape[0], 1, 1, 1) + + # Background Noise + sigma = 0.2 + noise = torch.randn_like(fft_magnitude) * sigma + mean_mag = fft_magnitude.sum() / (mask.sum() + 1) + + noise_magnitude_high = noise * (mean_mag) * (1 - mask) # high_freq_mask + + sigma = 0.1 + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude_low = noise * fft_magnitude * mask # (1 - high_freq_mask) + + # fft_noisy_magnitude = fft_magnitude * mask + noise_magnitude * high_freq_mask * (1 - mask) + fft_noisy_magnitude = fft_magnitude * mask + fft_noisy_magnitude += noise_magnitude_high + noise_magnitude_low + fft_noisy_magnitude = torch.clamp(fft_noisy_magnitude, min=0.0) + + fft = fft_noisy_magnitude * torch.exp(1j * fft_phase) + + else: + fft = fft * mask + + x_ksu = apply_to_spatial(fft, params_dict, pixel_range) + if return_mask: + return x_ksu, fft, fft_magnitude + + return x_ksu + + +def apply_tofre(x_start, mask, params_dict=None, pixel_range='mean_std'): + fft = fftshift(fft2(x_start)) + mask = mask.to(fft.device) + return fft, mask # , _min, _max + + +def apply_to_spatial(fft, params_dict=None, pixel_range='mean_std'): + x_ksu = ifft2(ifftshift(fft)) + x_ksu = torch.abs(x_ksu) + + return x_ksu + + +if __name__ == "__main__": + # First STEP + import SimpleITK as sitk + + import numpy as np + import os + + image_size = 240 + batch_size = 1 + t = 5 + + + use_linux = True + + # Load MRI back here + if use_linux: + root = "/gamedrive/Datasets/medical/Brain/brats/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData" + p_id = 639 + modality = "T1C" + filename = f"{root}/BraTS-GLI-{p_id:05d}-000/BraTS-GLI-{p_id:05d}-000-{modality.lower()}.nii.gz" + img_obj = sitk.ReadImage(filename) + img_array = sitk.GetArrayFromImage(img_obj) + + slice = img_array.shape[0] // 2 + img = img_array[slice, ...] 
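+        # Take the middle slice of the loaded volume and rescale it to [0, 1]
+        # before simulating the progressive k-space undersampling below.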
+ plt.imshow(img, cmap="gray") + plt.show() + img = (img - img.min()) / (img.max() - img.min()) + + plt.imsave("visualization/original.png", img, cmap="gray") + + else: + # Or use PNG + img = plt.imread( + "/Users/haochen/Documents/GitHub/DiffDecomp/Cold-Diffusion/generation-diffusion-pytorch/diffusion_pytorch/assets/img.png") + img = np.transpose(img, (2, 0, 1))[0] + + + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + print("img shape=", img.shape) + + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + + ksu_routine = "LogSamplingRate" # "LinearSamplingRate" # + kspace_kernels, patch_drop_masks = get_ksu_kernel(t, image_size, + ksu_routine=ksu_routine, is_training=True, + example_frequency_img=example) + kspace_kernels = torch.stack(kspace_kernels).squeeze(1) + + # all k_space + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + for i in range(batch_size): + # k = kspace_kernels[j].clone() + rand_kernels.append(torch.stack( + [kspace_kernels[j][: image_size, # rand_x[i]:rand_x[i] + + : image_size] for j in + range(len(kspace_kernels))])) + + rand_kernels = torch.stack(rand_kernels) + + # rand_kernels shape: torch.Size([24, 5, 128, 128]) + ori_img = img + + masked_img = [] + masks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + masks.append(k) + + img = apply_ksu_kernel(img, k, pixel_range='0_1') + + masked_img.append(img) + + # Visualize the masks + masks = np.concatenate(masks, axis=-1)[0] + + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # Save individually + + print("masks / masked_img=", masks.max(), masked_img.max()) + # img = np.concatenate([masks, masked_img], axis=0) + + plt.imsave("visualization/sample_masks.png", masks, cmap='gray') + + # masked_img = (masked_img - masked_img.min())/(masked_img) + # masked_img = np.concatenate([masked_img, 1-masked_img], axis=0) + plt.imsave("visualization/sample_images.png", masked_img, cmap='gray') + + w = masked_img.shape[0] + pr_folder = "visualization/progressive" + os.makedirs(pr_folder, exist_ok=True) + + # Progressive + print() + for i in range(t): + plt.imsave(f"{pr_folder}/{i}_masked_img.png", masked_img[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_masks.png", masks[:, i * w: (i + 1) * w], cmap='gray') + + img = ori_img + # use_fre_noise=False, return_mask=False + masked_img = [] + masks = [] + fft = [] + ks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + ks.append(k) + + img, k, fft_original = apply_ksu_kernel(img, k, pixel_range='0_1', use_fre_noise=True, return_mask=True) + + # k -> fft + fft_magnitude = np.abs(k) # 幅度 + # fft_phase = torch.angle(k) # 相位 + + mag = np.log(fft_magnitude[0]) + masks.append(mag) + fft.append(np.log(fft_original[0])) + + masked_img.append(img) + + # Visualize the masks + masks = np.concatenate(masks, axis=-1)[0] + ks = np.concatenate(ks, axis=-1)[0] + + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + + fft = np.concatenate(fft, axis=-1)[0] + + plt.imsave("visualization/sample_noisy_mask.png", masks, cmap='gray') + + # masked_img = np.concatenate([masked_img, 1 - masked_img], axis=0) + plt.imsave("visualization/sample_noisy_image.png", masked_img, cmap='gray') + # print("masked_img shape=", masked_img.shape, w) + + # Progressive + for i in range(t): + # 
print("masked_img[:, t*w: (t+1)*w] = ", masked_img[:, t*w: (t+1)*w].shape, t*w) + + plt.imsave(f"{pr_folder}/{i}_n_masked_img.png", masked_img[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_masks.png", masks[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_fft.png", fft[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_ks.png", ks[:, i * w: (i + 1) * w], cmap='gray') + + + diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/mask_utils.py b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2fea4433a26d0e67e3f81119a67d43a6e46598 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/mask_utils.py @@ -0,0 +1,680 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib +from typing import Optional, Sequence, Tuple, Union + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng: np.random.RandomState, seed: Optional[Union[int, Tuple[int, ...]]]): + """A context manager for temporarily adjusting the random seed.""" + if seed is None: + try: + yield + finally: + pass + else: + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +class MaskFunc: + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + + When called, ``MaskFunc`` uses internal functions create mask by 1) + creating a mask for the k-space center, 2) create a mask outside of the + k-space center, and 3) combining them into a total mask. The internals are + handled by ``sample_mask``, which calls ``calculate_center_mask`` for (1) + and ``calculate_acceleration_mask`` for (2). The combination is executed + in the ``MaskFunc`` ``__call__`` function. + + If you would like to implement a new mask, simply subclass ``MaskFunc`` + and overwrite the ``sample_mask`` logic. See examples in ``RandomMaskFunc`` + and ``EquispacedMaskFunc``. + """ + + def __init__( + self, + center_fractions: Sequence[float], + accelerations: Sequence[int], + allow_any_combination: bool = False, + seed: Optional[int] = None, + ): + """ + Args: + center_fractions: Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is + chosen uniformly each time. + accelerations: Amount of under-sampling. This should have the same + length as center_fractions. If multiple values are provided, + then one of these is chosen uniformly each time. + allow_any_combination: Whether to allow cross combinations of + elements from ``center_fractions`` and ``accelerations``. + seed: Seed for starting the internal random number generator of the + ``MaskFunc``. + """ + if len(center_fractions) != len(accelerations) and not allow_any_combination: + raise ValueError( + "Number of center fractions should match number of accelerations " + "if allow_any_combination is False." 
+ ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.allow_any_combination = allow_any_combination + self.rng = np.random.RandomState(seed) + + def __call__( + self, + shape: Sequence[int], + offset: Optional[int] = None, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + ) -> Tuple[torch.Tensor, int]: + """ + Sample and return a k-space mask. + + Args: + shape: Shape of k-space. + offset: Offset from 0 to begin mask (for equispaced masks). If no + offset is given, then one is selected randomly. + seed: Seed for random number generator for reproducibility. + + Returns: + A 2-tuple containing 1) the k-space mask and 2) the number of + center frequency lines. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_mask, accel_mask, num_low_frequencies = self.sample_mask( + shape, offset + ) + + # combine masks together + return torch.max(center_mask, accel_mask), num_low_frequencies + + def sample_mask( + self, + shape: Sequence[int], + offset: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Sample a new k-space mask. + + This function samples and returns two components of a k-space mask: 1) + the center mask (e.g., for sensitivity map calculation) and 2) the + acceleration mask (for the edge of k-space). Both of these masks, as + well as the integer of low frequency samples, are returned. + + Args: + shape: Shape of the k-space to subsample. + offset: Offset from 0 to begin mask (for equispaced masks). + + Returns: + A 3-tuple contaiing 1) the mask for the center of k-space, 2) the + mask for the high frequencies of k-space, and 3) the integer count + of low frequency samples. + """ + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + num_low_frequencies = round(float(num_cols * center_fraction)) + center_mask = self.reshape_mask( + self.calculate_center_mask(shape, num_low_frequencies), shape + ) + acceleration_mask = self.reshape_mask( + self.calculate_acceleration_mask( + num_cols, acceleration, offset, num_low_frequencies + ), + shape, + ) + + return center_mask, acceleration_mask, num_low_frequencies + + def reshape_mask(self, mask: np.ndarray, shape: Sequence[int]) -> torch.Tensor: + """Reshape mask to desired output shape.""" + num_cols = shape[-2] + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + if isinstance(mask, torch.Tensor): + return mask.view(*mask_shape).to(torch.float32) + return torch.from_numpy(mask.reshape(*mask_shape)).to(torch.float32) # torch.from_numpy( + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + """ + Produce mask for non-central acceleration lines. + + Args: + num_cols: Number of columns of k-space (2D subsampling). + acceleration: Desired acceleration rate. + offset: Offset from 0 to begin masking (for equispaced masks). + num_low_frequencies: Integer count of low-frequency lines sampled. + + Returns: + A mask for the high spatial frequencies of k-space. + """ + raise NotImplementedError + + def calculate_center_mask( + self, shape: Sequence[int], num_low_freqs: int + ) -> np.ndarray: + """ + Build center mask based on number of low frequencies. + + Args: + shape: Shape of k-space to mask. + num_low_freqs: Number of low-frequency lines to sample. + + Returns: + A mask for hte low spatial frequencies of k-space. 
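As a rough worked example of the centering arithmetic in the method body below (values assumed purely for illustration), the ``num_low_freqs`` middle columns are the ones kept:

    import numpy as np
    num_cols, num_low_freqs = 10, 4
    pad = (num_cols - num_low_freqs + 1) // 2      # -> 3
    mask = np.zeros(num_cols, dtype=np.float32)
    mask[pad : pad + num_low_freqs] = 1            # [0 0 0 1 1 1 1 0 0 0]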
+ """ + num_cols = shape[-2] + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = 1 + assert mask.sum() == num_low_freqs + + return mask + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + if self.allow_any_combination: + return self.rng.choice(self.center_fractions), self.rng.choice( + self.accelerations + ) + else: + choice = self.rng.randint(len(self.center_fractions)) + return self.center_fractions[choice], self.accelerations[choice] + + + + +class RandomMaskFunc(MaskFunc): + """ + Creates a random sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the ``RandomMaskFunc`` object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + + prob = (num_cols / acceleration - num_low_frequencies) / ( + num_cols - num_low_frequencies + ) + + # mask = self.rng.uniform(size=num_cols) < prob + # return torch.from_numpy(mask.astype(np.float32)) + + # return self.rng.uniform(size=num_cols) < prob + return torch.rand(num_cols) < prob + + + # mask = self.rng.uniform(size=num_cols) < prob + # pad = (num_cols - num_low_freqs + 1) // 2 + # mask[pad: pad + num_low_freqs] = True + # + # # reshape the mask + # mask_shape = [1 for _ in shape] + # mask_shape[-2] = num_cols + # mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + + + +class RandomPatchFunc(MaskFunc): + """ + Creates a random sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the ``RandomMaskFunc`` object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. 
+ """ + + def reshape_mask(self, mask: np.ndarray, shape: Sequence[int]) -> torch.Tensor: + """Reshape mask to desired output shape.""" + # num_cols = shape[0] * shape[1] + mask_shape = [1 for _ in shape] + mask_shape[-2] = shape[0] #num_cols + mask_shape[-1] = shape[1] + + if isinstance(mask, torch.Tensor): + return mask.view(*mask_shape).to(torch.float32) + return torch.from_numpy(mask.reshape(*mask_shape)).to(torch.float32) # torch.from_numpy( + + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + + prob = (num_cols / acceleration - num_low_frequencies) / ( + num_cols - num_low_frequencies + ) + + # mask = self.rng.uniform(size=num_cols) < prob + # return torch.from_numpy(mask.astype(np.float32)) + + # return self.rng.uniform(size=num_cols) < prob + return torch.rand(num_cols) < prob + + + def calculate_center_mask( + self, shape: Sequence[int], num_low_freqs: int + ) -> np.ndarray: + """ + Build center mask based on number of low frequencies. + + Args: + shape: Shape of k-space to mask. + num_low_freqs: Number of low-frequency lines to sample. + + Returns: + A mask for hte low spatial frequencies of k-space. + """ + # print("shape = ", shape) + num_cols = shape[0] * shape[1] + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad: pad + num_low_freqs] = 1 + assert mask.sum() == num_low_freqs + + return mask + + + def sample_mask( + self, + shape: Sequence[int], + offset: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Sample a new k-space mask. + + This function samples and returns two components of a k-space mask: 1) + the center mask (e.g., for sensitivity map calculation) and 2) the + acceleration mask (for the edge of k-space). Both of these masks, as + well as the integer of low frequency samples, are returned. + + Args: + shape: Shape of the k-space to subsample. + offset: Offset from 0 to begin mask (for equispaced masks). + + Returns: + A 3-tuple contaiing 1) the mask for the center of k-space, 2) the + mask for the high frequencies of k-space, and 3) the integer count + of low frequency samples. + """ + # print("sample mask shape= ", shape) + + + num_cols = shape[1] * shape[0] + center_fraction, acceleration = self.choose_acceleration() + num_low_frequencies = round(float(num_cols * center_fraction)) + center_mask = self.reshape_mask( + self.calculate_center_mask(shape, num_low_frequencies), shape + ) + acceleration_mask = self.reshape_mask( + self.calculate_acceleration_mask( + num_cols, acceleration, offset, num_low_frequencies + ), + shape, + ) + + return center_mask, acceleration_mask, num_low_frequencies + + + def __call__( + self, + shape: Sequence[int], + offset: Optional[int] = None, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + ) -> Tuple[torch.Tensor, int]: + """ + Sample and return a k-space mask. + + Args: + shape: Shape of k-space. + offset: Offset from 0 to begin mask (for equispaced masks). If no + offset is given, then one is selected randomly. + seed: Seed for random number generator for reproducibility. + + Returns: + A 2-tuple containing 1) the k-space mask and 2) the number of + center frequency lines. 
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_mask, accel_mask, num_low_frequencies = self.sample_mask( + shape, offset + ) + + # combine masks together + return torch.max(center_mask, accel_mask), num_low_frequencies + + + + + + +class EquiSpacedMaskFunc(MaskFunc): + """ + Sample data with equally-spaced k-space lines. + + The lines are spaced exactly evenly, as is done in standard GRAPPA-style + acquisitions. This means that with a densely-sampled center, + ``acceleration`` will be greater than the true acceleration rate. + """ + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + """ + Produce mask for non-central acceleration lines. + + Args: + num_cols: Number of columns of k-space (2D subsampling). + acceleration: Desired acceleration rate. + offset: Offset from 0 to begin masking. If no offset is specified, + then one is selected randomly. + num_low_frequencies: Not used. + + Returns: + A mask for the high spatial frequencies of k-space. + """ + if not isinstance(acceleration, int): + acceleration = int(acceleration.item()) + + if offset is None: + offset = self.rng.randint(0, high=round(acceleration)) + + mask = np.zeros(num_cols, dtype=np.float32) + mask[offset::acceleration] = 1 + + return mask + + +class EquispacedMaskFractionFunc(MaskFunc): + """ + Equispaced mask with approximate acceleration matching. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + """ + Produce mask for non-central acceleration lines. + + Args: + num_cols: Number of columns of k-space (2D subsampling). + acceleration: Desired acceleration rate. + offset: Offset from 0 to begin masking. If no offset is specified, + then one is selected randomly. + num_low_frequencies: Number of low frequencies. Used to adjust mask + to exactly match the target acceleration. + + Returns: + A mask for the high spatial frequencies of k-space. 
+ """ + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_frequencies - num_cols)) / ( + num_low_frequencies * acceleration - num_cols + ) + if offset is None: + offset = self.rng.randint(0, high=round(adjusted_accel)) + + mask = np.zeros(num_cols) + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = 1.0 + + return mask + + +class MagicMaskFunc(MaskFunc): + """ + Masking function for exploiting conjugate symmetry via offset-sampling. + + This function applies the mask described in the following paper: + + Defazio, A. (2019). Offset Sampling Improves Deep Learning based + Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint, + arXiv:1912.01101. + + It is essentially an equispaced mask with an offset for the opposite site + of k-space. Since MRI images often exhibit approximate conjugate k-space + symmetry, this mask is generally more efficient than a standard equispaced + mask. + + Similarly to ``EquispacedMaskFunc``, this mask will usually undereshoot the + target acceleration rate. + """ + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + """ + Produce mask for non-central acceleration lines. + + Args: + num_cols: Number of columns of k-space (2D subsampling). + acceleration: Desired acceleration rate. + offset: Offset from 0 to begin masking. If no offset is specified, + then one is selected randomly. + num_low_frequencies: Not used. + + Returns: + A mask for the high spatial frequencies of k-space. + """ + if offset is None: + offset = self.rng.randint(0, high=acceleration) + + if offset % 2 == 0: + offset_pos = offset + 1 + offset_neg = offset + 2 + else: + offset_pos = offset - 1 + 3 + offset_neg = offset - 1 + 0 + + poslen = (num_cols + 1) // 2 + neglen = num_cols - (num_cols + 1) // 2 + mask_positive = np.zeros(poslen, dtype=np.float32) + mask_negative = np.zeros(neglen, dtype=np.float32) + + mask_positive[offset_pos::acceleration] = 1 + mask_negative[offset_neg::acceleration] = 1 + mask_negative = np.flip(mask_negative) + + mask = np.concatenate((mask_positive, mask_negative)) + + return np.fft.fftshift(mask) # shift mask and return + + +class MagicMaskFractionFunc(MagicMaskFunc): + """ + Masking function for exploiting conjugate symmetry via offset-sampling. + + This function applies the mask described in the following paper: + + Defazio, A. (2019). Offset Sampling Improves Deep Learning based + Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint, + arXiv:1912.01101. + + It is essentially an equispaced mask with an offset for the opposite site + of k-space. Since MRI images often exhibit approximate conjugate k-space + symmetry, this mask is generally more efficient than a standard equispaced + mask. + + Similarly to ``EquispacedMaskFractionFunc``, this method exactly matches + the target acceleration by adjusting the offsets. + """ + + def sample_mask( + self, + shape: Sequence[int], + offset: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Sample a new k-space mask. + + This function samples and returns two components of a k-space mask: 1) + the center mask (e.g., for sensitivity map calculation) and 2) the + acceleration mask (for the edge of k-space). Both of these masks, as + well as the integer of low frequency samples, are returned. 
+ + Args: + shape: Shape of the k-space to subsample. + offset: Offset from 0 to begin mask (for equispaced masks). + + Returns: + A 3-tuple contaiing 1) the mask for the center of k-space, 2) the + mask for the high frequencies of k-space, and 3) the integer count + of low frequency samples. + """ + num_cols = shape[-2] + fraction_low_freqs, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_frequencies = round(num_cols * fraction_low_freqs) + + # bound the number of low frequencies between 1 and target columns + target_columns_to_sample = round(num_cols / acceleration) + num_low_frequencies = max(min(num_low_frequencies, target_columns_to_sample), 1) + + # adjust acceleration rate based on target acceleration. + adjusted_target_columns_to_sample = ( + target_columns_to_sample - num_low_frequencies + ) + adjusted_acceleration = 0 + if adjusted_target_columns_to_sample > 0: + adjusted_acceleration = round(num_cols / adjusted_target_columns_to_sample) + + center_mask = self.reshape_mask( + self.calculate_center_mask(shape, num_low_frequencies), shape + ) + accel_mask = self.reshape_mask( + self.calculate_acceleration_mask( + num_cols, adjusted_acceleration, offset, num_low_frequencies + ), + shape, + ) + + return center_mask, accel_mask, num_low_frequencies + + +def create_mask_for_mask_type( + mask_type_str: str, + center_fractions: Sequence[float], + accelerations: Sequence[int], +) -> MaskFunc: + """ + Creates a mask of the specified type. + + Args: + center_fractions: What fraction of the center of k-space to include. + accelerations: What accelerations to apply. + + Returns: + A mask func for the target mask type. + """ + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquiSpacedMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced_fraction": + return EquispacedMaskFractionFunc(center_fractions, accelerations) + elif mask_type_str == "magic": + return MagicMaskFunc(center_fractions, accelerations) + elif mask_type_str == "magic_fraction": + return MagicMaskFractionFunc(center_fractions, accelerations) + else: + raise ValueError(f"{mask_type_str} not supported") diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/original.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/original.png new file mode 100644 index 0000000000000000000000000000000000000000..8e9661372201bbd5c809478e5060f7c89c408c69 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/original.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..61c1d7d7c6ecbb701b85747cd6d2faf1e2fde9b6 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_fft.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and 
b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_fft.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_ks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..61c1d7d7c6ecbb701b85747cd6d2faf1e2fde9b6 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_ks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_masked_img.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..1b74feb3ab2cf078b641e9369acc31db5dbd70e6 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_masked_img.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/0_n_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_masked_img.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..000a9816d664f7a6282d9da2ba81aabdd652103e Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_masked_img.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..d04fa7b7e31f5d43205e7b7f9d50afbd02a51c54 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_fft.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_fft.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_ks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..d04fa7b7e31f5d43205e7b7f9d50afbd02a51c54 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_ks.png differ diff --git 
a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_masked_img.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..1f1e68643d94e2570045c15aced74c8f2e8c7823 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_masked_img.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/1_n_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_masked_img.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..3a323d0ddddc771fc6508161f93b919857c6eedb Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_masked_img.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..227825b2b4bf23e8d456e9b292af155e51f0c76f Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_fft.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_fft.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_ks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..227825b2b4bf23e8d456e9b292af155e51f0c76f Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_ks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_masked_img.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..2ff77d3c4df7dee5be4edc909c3980362b614f1d Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_masked_img.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_masks.png 
b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/2_n_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_masked_img.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..a3d8492f7bbf1bb93c25f0e616f4000240f7fa5e Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_masked_img.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..3116c70ad5a72b9164c3da2f6fa2071e646372a2 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_fft.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_fft.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_ks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..3116c70ad5a72b9164c3da2f6fa2071e646372a2 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_ks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_masked_img.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..56c1877dc19f8ef25f99ffe435ba41940ac7f394 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_masked_img.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/3_n_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_masked_img.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_masked_img.png new file mode 100644 index 
0000000000000000000000000000000000000000..fb66bd61130361dc3ed8deb95ad6fb030135efd4 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_masked_img.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..7b6952955f0a71f402b4d1bd0851324cd79194ae Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_fft.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_fft.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_ks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..7b6952955f0a71f402b4d1bd0851324cd79194ae Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_ks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_masked_img.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..28d1b93788c84e2579d53425432e920fb298c190 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_masked_img.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/progressive/4_n_masks.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_images.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_images.png new file mode 100644 index 0000000000000000000000000000000000000000..5547693cce0491b530cf67b50bf7bd91d5bd91c6 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_images.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_masks.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..0db686e17102eb660c650a4081ef841b374741cc Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_masks.png differ diff --git 
a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_noisy_image.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_noisy_image.png new file mode 100644 index 0000000000000000000000000000000000000000..afedd4450bab040a01ffaa6af5af63617b2e0747 Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_noisy_image.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_noisy_mask.png b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_noisy_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..ff7c980e2daf9395368f8a73b44d3bb2f8efdc9a Binary files /dev/null and b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/degradation/visualization/sample_noisy_mask.png differ diff --git a/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/diffusion_gaussian.py b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/diffusion_gaussian.py new file mode 100644 index 0000000000000000000000000000000000000000..9560f60b2545cb40f4f903e884de42f0a014c8ce --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/diffusion_pytorch/diffusion_gaussian.py @@ -0,0 +1,2045 @@ +import copy, time +import gc + +import torch +from torch import nn +import torch.nn.functional as func +from inspect import isfunction +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import AdamW, lr_scheduler +from torchvision import utils +import torch.nn.functional as F +# from einops import rearrange + +import os +import errno +from PIL import Image +# from pytorch_msssim import ssim +import cv2 +import numpy as np +import imageio +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +# from torch.utils.tensorboard import SummaryWriter +from diffusion_pytorch.degradation.k_degradation import get_fade_kernels, get_ksu_kernel, apply_ksu_kernel, apply_tofre, apply_to_spatial +from dataset import Dataset, Dataset_Aug1, BrainDataset + +# from skimage.metrics import structural_similarity as ssim +from skimage.metrics import peak_signal_noise_ratio as psnr + +from torchmetrics.image import StructuralSimilarityIndexMeasure + +from metrics.lpips import LPIPS +from metrics.fid import calculate_fid +from metrics.fid_3d import calculate_fid_3d +from metrics.nmse import nmse +from diffusion_pytorch.degradation.k_degradation import get_noisy_patches +import torch.amp as amp +from torch.cuda.amp import GradScaler, autocast +from metrics.frequency_loss import AMPLoss +scaler = GradScaler() + +# try: +# from apex import amp +# +# APEX_AVAILABLE = True +# except: +# APEX_AVAILABLE = False + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def create_folder(path): + try: + os.mkdir(path) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + pass + + +def cycle(dl): + while True: + for inputs in dl: + yield inputs + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def loss_backwards(fp16, loss, optimizer, **kwargs): + if fp16: + scaler.scale(loss).backward(**kwargs) + else: + loss.backward(**kwargs) + + + +class EMA: + def __init__(self, beta): + super().__init__() + self.beta = beta + + 
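A minimal sketch of how this ``EMA`` helper is driven (toy module assumed for illustration; the ``Trainer`` further down calls ``update_model_average(self.ema_model, self.model)`` in the same way):

    import copy
    import torch.nn as nn

    model = nn.Linear(4, 4)                # stand-in for the diffusion model
    ema_model = copy.deepcopy(model)
    ema = EMA(beta=0.995)
    for _ in range(3):
        # ...an optimizer step on `model` would normally happen here...
        ema.update_model_average(ema_model, model)   # ema <- 0.995 * ema + 0.005 * current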
def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + diffusion_type, + restore_fn, + *, + image_size, + device_of_kernel, + channels=3, + timesteps=1000, + loss_type='l1', + kernel_std=0.1, + initial_mask=11, + fade_routine='Incremental', + sampling_routine='default', + discrete=False, + accelerate_factor=4, + fp16=False, + normalizer="mean_std", + example_frequency_img=None + + ): + super().__init__() + self.fp16 = fp16 + self.channels = channels + self.image_size = image_size + self.restore_fn = restore_fn + self.accelerate_factor = accelerate_factor + + # self.backbone = diffusion_type.split('_')[0] + self.example_frequency_img = example_frequency_img + self.device_of_kernel = device_of_kernel + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + self.kernel_std = kernel_std + self.initial_mask = initial_mask + self.fade_routine = fade_routine + self.backbone = diffusion_type.split('_')[0] + self.degradation_type = diffusion_type.split('_')[1] + + self.sampling_routine = sampling_routine + self.discrete = discrete + + # Frequency Loss + if self.backbone == 'twobranch' or self.backbone == 'twounet': + self.amploss = AMPLoss() # .to(self.device, non_blocking=True) + self.kl_loss = torch.nn.KLDivLoss(reduction='sum') # sum , log_target=True + + self.lpips = LPIPS().eval().cuda() # .to(self.device, non_blocking=True) + self.ssim = StructuralSimilarityIndexMeasure() + + self.use_fre_loss = True + + self.use_ssim = False + self.use_lpips = True # 141 -> 144s + + self.clamp_every_sample = False # Stride + if normalizer == "min_max": + self.clamp_every_sample = True + + self.use_fre_noise = True + + self.update_kernel = False + self.use_patch_kernel = False + self.use_kl = False + + + if self.degradation_type == 'fade': + self.fade_kernels = get_fade_kernels(fade_routine, self.num_timesteps, image_size, kernel_std, initial_mask) + # print("=== self.fade_kernels shape = ", self.fade_kernels.shape) # [5, 256, 256] + + elif self.degradation_type == "kspace": + self.get_new_kspace(is_training=True) + # print("=== self.kspace_kernels shape = ", self.kspace_kernels.shape) # [5, 256, 256] + else: + raise NotImplementedError() + + def get_new_kspace(self, is_training=False): + # LinearSamplingRate, LogSamplingRate + self.kspace_kernels, self.noisy_kspace_kernels = get_ksu_kernel(self.num_timesteps, self.image_size, + ksu_routine="LogSamplingRate", is_training=is_training, + accelerated_factor=self.accelerate_factor, + example_frequency_img=self.example_frequency_img) + + self.kspace_kernels = torch.stack(self.kspace_kernels).squeeze(1).cuda() + self.noisy_kspace_kernels =self.kspace_kernels.cuda() + + + def get_kspace_kernels(self, index): + k = torch.stack([self.kspace_kernels[index]], 0).unsqueeze(0) + return k + + @torch.no_grad() + def sample(self, batch_size=16, faded_recon_sample=None, aux=None, + t=None, params_dict=None, sample_routine=None): + # Test + self.restore_fn.eval() + + if not sample_routine: + sample_routine = self.sampling_routine + + sample_device = faded_recon_sample.device + batch_size = faded_recon_sample.size(0) + + + + if t is None: + t = self.num_timesteps + + 
# print("self.kspace_kernels = ", self.kspace_kernels.shape) # self.kspace_kernels = torch.Size([5, 320, 320]) + # print("faded_recon_sample = ", faded_recon_sample.shape) # faded_recon_sample = torch.Size([16, 3, 320, 320]) + + # for i in range(t): + with torch.no_grad(): + k = torch.stack([self.kspace_kernels[[t - 1]]], 1) + # print("k = ", k.shape) # k = torch.Size([1, 1, 320, 320]) + faded_recon_sample = apply_ksu_kernel(faded_recon_sample, k, params_dict) + + return_k = k.repeat( batch_size, 1, 1, 1) + + xt = faded_recon_sample + # print("faded_recon_sample = ", faded_recon_sample.shape) + + + direct_recons = None + recon_sample = None + all_recons = [] + all_recons_fre = [] + all_masks = [] + + k_known_mask = torch.zeros_like(self.get_kspace_kernels(-1)).cuda() + + while t: + step = torch.full((batch_size,), t - 1, dtype=torch.long).cuda() + if self.backbone == "unet": + recon_sample = self.restore_fn(faded_recon_sample, step) + + elif self.backbone == "twounet": + recon_sample = self.restore_fn(faded_recon_sample, aux, k, step) + + elif self.backbone == "twobranch": + recon_sample, recon_fre = self.restore_fn(faded_recon_sample, aux, step) + all_recons_fre.append(recon_fre) + + if direct_recons is None: + direct_recons = recon_sample + all_recons.append(recon_sample) + + if self.degradation_type == 'kspace': + # faded_recon_sample = recon_sample + + if sample_routine == 'default': + all_recons.append(recon_sample) + with torch.no_grad(): + if t >=1: + k = self.get_kspace_kernels(t - 1) + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + + all_masks.append(k) + faded_recon_sample = recon_sample + + + elif sample_routine == 'x0_step_down': + all_recons.append(recon_sample) + if t <= 1: + if t == 1: + # recon_sample_sub_1 = recon_sample + # k = self.get_kspace_kernels(0, self.kspace_kernels) + # + # recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + faded_recon_sample = recon_sample #faded_recon_sample - recon_sample + recon_sample_sub_1 + + else: + faded_recon_sample = recon_sample + all_masks.append(k) + else: + with torch.no_grad(): + k = self.get_kspace_kernels(t - 2, self.kspace_kernels) + recon_sample_sub_1 = apply_ksu_kernel(recon_sample, k, params_dict) + + k = self.get_kspace_kernels(t - 1, self.kspace_kernels) + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + + faded_recon_sample = faded_recon_sample - recon_sample + recon_sample_sub_1 + + if self.clamp_every_sample: + faded_recon_sample = faded_recon_sample.clamp(-1, 1) + all_masks.append(k) + + elif sample_routine == 'x0_step_down_fre': + all_recons.append(recon_sample) + if t <= 1: + kt = self.get_kspace_kernels(0).cuda() # last one + faded_recon_sample = recon_sample + k_residual = torch.ones_like(kt).cuda() + + + else: + k_full = self.get_kspace_kernels(- 1) + faded_recon_sample_fre, k_full = apply_tofre(faded_recon_sample, k_full, params_dict) + # print('k_full = ', k_full.shape) + + with torch.no_grad(): + + kt_sub_1 = self.get_kspace_kernels(t - 1).cuda() + kt = self.get_kspace_kernels(t - 0).cuda() # last one + k_residual = kt_sub_1 - kt + recon_sample_fre, k_residual = apply_tofre(recon_sample, k_residual, params_dict) + + fre_amend = recon_sample_fre * k_residual + faded_recon_sample_fre = faded_recon_sample_fre + fre_amend # * (1-k_residual) + + faded_recon_sample_fre = faded_recon_sample_fre * kt_sub_1 + + faded_recon_sample = apply_to_spatial(faded_recon_sample_fre, params_dict) + + k_known_mask += k_residual #.cpu() + all_masks.append(k_known_mask.cpu().clone()) + + if 
self.clamp_every_sample: + faded_recon_sample = faded_recon_sample.clamp(-1, 1) + + + elif sample_routine == 'fre_progressive': + all_recons.append(recon_sample) + if t == 1: + recon_sample_sub_1 = recon_sample + k = self.get_kspace_kernels(0, self.kspace_kernels) + + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + faded_recon_sample = faded_recon_sample - recon_sample + recon_sample_sub_1 + + elif t == 0: + faded_recon_sample = recon_sample + all_masks.append(k) + else: + k_full = self.get_kspace_kernels(- 1, self.kspace_kernels) + faded_recon_sample_fre, k_full = apply_tofre(faded_recon_sample, k_full, params_dict) + + with torch.no_grad(): + + kt_sub_1 = self.get_kspace_kernels(t - 2, self.kspace_kernels).cuda() + kt = self.get_kspace_kernels(t - 1, self.kspace_kernels).cuda() # last one + k_residual = kt_sub_1 - k_full + recon_sample_fre, k_residual = apply_tofre(recon_sample, k_residual, params_dict) + + + fre_amend = recon_sample_fre * k_residual # new + faded_recon_sample_fre = faded_recon_sample_fre * k_full + fre_amend # * (1-k_residual) + + # faded_recon_sample_fre = faded_recon_sample_fre * kt_sub_1 # + recon_sample * (1 - kt_sub_1) + + faded_recon_sample = apply_to_spatial(faded_recon_sample_fre, params_dict) + + k_known_mask += k_residual # .cpu() + all_masks.append(k_known_mask.cpu().clone()) + + if self.clamp_every_sample: + faded_recon_sample = faded_recon_sample.clamp(-1, 1) + + + + recon_sample = faded_recon_sample + # print("recon_sample = ", recon_sample.shape) + + t -= 1 + + all_recons = torch.stack(all_recons) + all_masks = torch.stack(all_masks) + all_recons_fre = torch.stack(all_recons_fre) + + return xt, direct_recons, recon_sample, return_k, all_recons, all_recons_fre, all_masks + + @torch.no_grad() + def all_sample(self, batch_size=16, faded_recon_sample=None, aux=None, t=None, params_dict=None, times=None): + # TODO + print("Running into all_sample...") + rand_kernels = None + sample_device = faded_recon_sample.device + if self.degradation_type == 'fade': + if 'Random' in self.fade_routine: + rand_kernels = [] + rand_x = torch.randint(0, self.image_size + 1, (batch_size,), device=faded_recon_sample.device).long() + rand_y = torch.randint(0, self.image_size + 1, (batch_size,), device=faded_recon_sample.device).long() + for i in range(batch_size, ): + rand_kernels.append(torch.stack( + [self.fade_kernels[j][rand_x[i]:rand_x[i] + self.image_size, + rand_y[i]:rand_y[i] + self.image_size] for j in range(len(self.fade_kernels))])) + rand_kernels = torch.stack(rand_kernels) + + elif self.degradation_type == 'kspace': + rand_kernels = [] + rand_x = torch.randint(0, self.image_size + 1, (batch_size,), device=faded_recon_sample.device).long() + + for i in range(batch_size, ): + rand_kernels.append(torch.stack( + [self.fade_kernels[j][rand_x[i]:rand_x[i] + self.image_size, + : self.image_size] for j in range(len(self.fade_kernels))])) + rand_kernels = torch.stack(rand_kernels) + + if t is None: + t = self.num_timesteps + if times is None: + times = t + + for i in range(t): + with torch.no_grad(): + if self.degradation_type == 'fade': + if 'Random' in self.fade_routine: + faded_recon_sample = torch.stack([rand_kernels[:, i].to(sample_device), + rand_kernels[:, i].to(sample_device), + rand_kernels[:, i].to(sample_device)], 1) * faded_recon_sample + else: + faded_recon_sample = self.fade_kernels[i].to(sample_device) * faded_recon_sample + elif self.degradation_type == 'kspace': + if rand_kernels is not None: + # print(f"kspace randkeynel k={rand_kernels[:, 
i].shape}, x={x.shape}") + k = torch.stack([rand_kernels[:, i]], 1) + faded_recon_sample = apply_ksu_kernel(faded_recon_sample, k, params_dict) + else: + # print(f"kspace k={self.kspace_kernels[i].shape}, x={x.shape}") + k = self.kspace_kernels[i] + faded_recon_sample = apply_ksu_kernel(faded_recon_sample, k, params_dict) + + if self.discrete: + faded_recon_sample = (faded_recon_sample + 1) * 0.5 + faded_recon_sample = (faded_recon_sample * 255) + faded_recon_sample = faded_recon_sample.int().float() / 255 + faded_recon_sample = faded_recon_sample * 2 - 1 + + x0_list = [] + xt_list = [] + + while times: + step = torch.full((batch_size,), times - 1, dtype=torch.long).cuda() + if self.backbone == "unet": + recon_sample = self.restore_fn(faded_recon_sample, step) + elif self.backbone == "twounet": + recon_sample = self.restore_fn(faded_recon_sample, aux, k, step) + + elif self.backbone == "twobranch": + recon_sample, recon_fre = self.restore_fn(faded_recon_sample, aux, step) + recon_sample = recon_sample #// 2 + recon_fre // 2 + + x0_list.append(recon_sample) + + if self.degradation_type == 'fade': + if self.sampling_routine == 'default': + for i in range(times - 1): + with torch.no_grad(): + if rand_kernels is not None: + recon_sample = torch.stack([rand_kernels[:, i].to(sample_device), + rand_kernels[:, i].to(sample_device), + rand_kernels[:, i].to(sample_device)], 1) * recon_sample + else: + recon_sample = self.fade_kernels[i].to(sample_device) * recon_sample + faded_recon_sample = recon_sample + + elif self.sampling_routine == 'x0_step_down': + for i in range(t): + with torch.no_grad(): + recon_sample_sub_1 = recon_sample + if rand_kernels is not None: + + recon_sample = apply_ksu_kernel(recon_sample, rand_kernels[i], params_dict) + else: + recon_sample = apply_ksu_kernel(recon_sample, self.kspace_kernels[i], params_dict) + + faded_recon_sample = faded_recon_sample - recon_sample + recon_sample_sub_1 + + elif self.degradation_type == 'kspace': + # faded_recon_sample = recon_sample + if self.sampling_routine == 'default': + for i in range(t - 1): + with torch.no_grad(): + if rand_kernels is not None: + k = torch.stack([rand_kernels[:, i]], 1) + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + else: + recon_sample = apply_ksu_kernel(recon_sample, self.kspace_kernels[i], params_dict) + + faded_recon_sample = recon_sample + + elif self.sampling_routine == 'x0_step_down': + for i in range(t): + with torch.no_grad(): + recon_sample_sub_1 = recon_sample + if rand_kernels is not None: + k = torch.stack([rand_kernels[:, i]], 1) + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + else: + recon_sample = apply_ksu_kernel(recon_sample, self.kspace_kernels[i], params_dict) + + faded_recon_sample = faded_recon_sample - recon_sample + recon_sample_sub_1 + + + xt_list.append(faded_recon_sample) + times -= 1 + + return x0_list, xt_list + + # Train + def q_sample(self, x_start, t, params_dict=None, use_fre_noise=False): + x = x_start + + with torch.no_grad(): + k = torch.stack([self.kspace_kernels[t]], 1) + x = apply_ksu_kernel(x, k, params_dict, use_fre_noise=use_fre_noise) # self.use_fre_noise + + return x, k + + + def reconstruct_loss(self, x_start, x_recon): + if self.loss_type == 'l1': + loss = (x_start - x_recon).abs().mean() + elif self.loss_type == 'l2': + loss = func.mse_loss(x_start, x_recon) + else: + raise NotImplementedError() + return loss + + + def gaussian_kernel(self, size: int, sigma: float): + """Generates a 2D Gaussian kernel.""" + x = 
torch.arange(size).float() - size // 2 + gauss = torch.exp(-x ** 2 / (2 * sigma ** 2)) + kernel = gauss[:, None] @ gauss[None, :] + kernel /= kernel.sum() + return kernel.cuda() + + def gaussian_blur(self, input_tensor, kernel_size: int, sigma: float): + """Applies Gaussian blur to a 4D tensor (N, C, H, W).""" + # Create Gaussian kernel + kernel = self.gaussian_kernel(kernel_size, sigma).unsqueeze(0).unsqueeze(0) + kernel = kernel.expand(input_tensor.size(1), 1, kernel_size, kernel_size) # For each channel + + # Pad the input tensor to avoid size reduction + padding = kernel_size // 2 + input_tensor = F.pad(input_tensor, (padding, padding, padding, padding), mode='reflect') + + # Apply convolution + blurred = F.conv2d(input_tensor, kernel, groups=input_tensor.size(1)) + return blurred + + def get_frequency_elements(self, x): + x_fft = torch.fft.rfft2(x, norm="ortho") + + # Perform FFT and compute magnitudes + x_mag = torch.clamp(torch.abs(x_fft), min=1e-8) + x_phase = torch.angle(x_fft) + + return x_mag, x_phase + + def get_fre_kl_loss(self, pred_spa, pred_fre, target, k): + # Flatten the elements + B = pred_spa.shape[0] + k = k.contiguous().view(B, -1) + pred_spa = pred_spa.view(B, -1) + pred_fre = pred_fre.view(B, -1) + target = target.view(B, -1) + + # minus max value + pred_spa = pred_spa - torch.max(pred_spa, dim=1, keepdim=True).values + pred_fre = pred_fre - torch.max(pred_fre, dim=1, keepdim=True).values + target = target - torch.max(target, dim=1, keepdim=True).values + + target_prob = F.softmax(target, dim=1) + + ele_num = 2 + target_all_prob = torch.cat([target_prob for ii in range(ele_num)], dim=0) + k_mask = torch.cat([k for ii in range(ele_num)], dim=0) + k_total = torch.sum(k_mask) + + pred_all = torch.cat([pred_spa, pred_fre]) # 4B + pred_all = F.log_softmax(pred_all, dim=1) + + + # consistency_loss + # get probability + pred_spa_prob = F.softmax(pred_spa, dim=1) + pred_fre_prob = F.softmax(pred_fre, dim=1) + pred_avg_prob = 1.0 / ele_num * (pred_spa_prob + pred_fre_prob) # 2 B + pred_avg_prob = torch.cat([pred_avg_prob for ii in range(ele_num)], dim=0).clone().detach() + + kl_consist_loss = (self.kl_loss(pred_all, pred_avg_prob) * k_mask).sum() / k_total + + return kl_consist_loss # + kl_loss + + + def frequency_consistency_loss(self, pred_spa, pred_fre, target, k): + ''' + KL-term, enforcing conditional distribution remains unchanged regardless of interventions applied + ''' + + W = pred_spa.shape[-1] + half_W = W // 2 + 1 + k = (1 - k.to(pred_spa.device)) # negative mask + k = k[..., :half_W] + + pred_spa_mag, pred_spa_pha = self.get_frequency_elements(pred_spa) + pred_fre_mag, pred_fre_pha = self.get_frequency_elements(pred_fre) + target_mag, target_pha = self.get_frequency_elements(target) + + mag_kl_loss = self.get_fre_kl_loss(pred_spa_mag, pred_fre_mag, target_mag, k) + pha_kl_loss = self.get_fre_kl_loss(pred_spa_pha, pred_fre_pha, target_pha, k) + + # print("mag loss=", mag_kl_loss, "pha loss=", pha_kl_loss) + return mag_kl_loss + pha_kl_loss + + + def p_losses(self, x_start, aux, t, params_dict): + self.debug_print = False + self.debug_time = False + + start_time = time.time() + + x_start_golden = x_start.clone() + x_mix, k = self.q_sample(x_start=x_start, t=t, params_dict=params_dict) + + # gaussian blur for x_mix + # if np.random.rand() > 0.5: + # x_mix = self.gaussian_blur( + # x_mix, + # kernel_size=int(torch.randint(1, 9, (1,)).item() * 2 + 1), # Ensure odd kernel size + # sigma=torch.abs(torch.randn(1) * 3.0).item() # Ensure sigma is positive + # ) + + # Add 
gaussian noise + # sigma = 0.1 * torch.abs(torch.rand(1)).item() # Standard Deviation + # x_mix = x_mix + torch.randn_like(x_mix) * sigma + # aux = aux + torch.randn_like(x_mix) * sigma + + x_mix = x_mix.detach() + aux = aux.detach() + x_start_golden = x_start_golden.detach() + k = k.detach() + + if self.debug_time: + print("--------------------") + print("sample time=", time.time() - start_time) # 0.02s ~ 0.03s + + + if self.backbone == 'unet': + x_recon = self.restore_fn(x_mix, t) + loss = self.reconstruct_loss(x_start_golden, x_recon) + + if self.use_lpips: + lpips_weight = 0.1 + lpips_loss = lpips_weight * self.lpips(x_recon, x_start_golden).mean() + loss += lpips_loss + + + if self.use_fre_loss: # NAN + fft_weight = 0.01 + + fre_loss = fft_weight * self.amploss(x_recon, x_start_golden, k) + loss += fre_loss + + + elif self.backbone == 'twounet': + x_recon = self.restore_fn(x_mix, aux, k, t) + loss = self.reconstruct_loss(x_start_golden, x_recon) * 5.0 + + # LPIPS + if self.use_lpips: + lpips_weight = 0.1 + lpips_loss = self.lpips(x_recon, x_start_golden).mean() + loss += lpips_weight * lpips_loss + + + if self.use_fre_loss: # NAN + fft_weight = 0.1 + amp = self.amploss(x_recon, x_start_golden, k) + loss += fft_weight * amp + + + elif self.backbone == 'twobranch': + + if self.fp16: + with autocast(): + x_recon, x_recon_fre = self.restore_fn(x_mix, aux, t) + else: + x_recon, x_recon_fre = self.restore_fn(x_mix, aux, t) + if self.debug_time: + print("restore_fn time=", time.time() - start_time) + + # img_mean = params_dict['img_mean'].cuda().view(-1, 1, 1, 1) + # img_std = params_dict['img_std'].cuda().view(-1, 1, 1, 1) + # x_start_golden = x_start_golden * img_std + img_mean # 0 - 1 + # x_recon = x_recon * img_std + img_mean + # x_recon_fre = x_recon_fre * img_std + img_mean + + loss_spatial = self.reconstruct_loss(x_start_golden, x_recon) + loss_freq = self.reconstruct_loss(x_start_golden, x_recon_fre) + loss = loss_spatial + loss_freq + + if self.debug_time: + print("reconstruct_loss time=", time.time() - start_time) + + # LPIPS + if self.use_lpips: + lpips_weight = 0.1 + lpips_loss = lpips_weight * self.lpips(x_recon, x_start_golden).mean() + loss += lpips_loss + + if self.use_ssim: + ssim_weight = 0.1 + ssim_loss = 1.0 - self.ssim(x_recon, x_start_golden).mean() + loss += ssim_weight * ssim_loss + + if self.use_fre_loss: # NAN + fft_weight = 0.01 + + fre_loss = fft_weight * self.amploss(x_recon_fre, x_start_golden, k) + loss += fre_loss + + # fre_loss = fft_weight * self.amploss(x_recon, x_start_golden, k) + # loss += fre_loss + + if self.use_kl: + amp_fre = fft_weight * self.frequency_consistency_loss(x_recon, x_recon_fre, x_start_golden, k) + loss += amp_fre + + if self.debug_time: + print("fre loss time=", time.time() - start_time) + print("--------------------") + + if np.random.rand() < 0.001: + print("----------------------------------------\n" + "loss_spatial:", loss_spatial.item(), + "loss_freq:", loss_freq.item(), + # "lpips_loss:", lpips_loss.item(), + # "ssim_loss", ssim_loss.item(), + "fre_loss:", fre_loss.item()) + + print("----------------------------------------\n" + "x_recon:", x_recon.min().item(), x_recon.max().item(), + "x_recon_fre:", x_recon_fre.min().item(), x_recon_fre.max().item(), + "x_start_golden:", x_start_golden.min().item(), x_start_golden.max().item()) + + return loss + + + def forward(self, x1, x2=None, params_dict=None, *args, **kwargs): + b, c, h, w, device, img_size, = *x1.shape, x1.device, self.image_size + assert h == img_size and w == img_size, 
f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + + loss = self.p_losses(x1, x2, t, params_dict, *args, **kwargs) + return loss + + + +class Trainer(nn.Module): + def __init__( + self, + diffusion_model, + folder, + mode, + *, + norm="mean_std", + ema_decay=0.995, + image_size=128, + train_batch_size=32, + train_lr=2e-5, + train_num_steps=700000, + gradient_accumulate_every=2, + fp16=False, + step_start_ema=2000, + update_ema_every=100, + save_and_sample_every=1000, + results_folder='./results', + load_path=None, + dataset=None, + shuffle=True, + domain=None, + aux_modality=None, + num_channels=1, + debug=False, + ): + super().__init__() + + self.mode = mode + self.model = diffusion_model + self.ema = EMA(ema_decay) + self.ema_model = copy.deepcopy(self.model) + self.update_ema_every = update_ema_every + + self.step_start_ema = step_start_ema + self.save_and_sample_every = save_and_sample_every if not debug else 10 + + self.batch_size = train_batch_size + self.image_size = diffusion_model.module.image_size + self.gradient_accumulate_every = gradient_accumulate_every + self.train_num_steps = train_num_steps + self.input_normalize = norm + + + if dataset == 'train': + print(dataset, "DA used") + self.ds = Dataset_Aug1(folder, image_size) + + elif dataset.lower() == 'brain': + print(dataset, "Brain DA used", "mode=", mode) + # mode, base_dir, image_size, nclass, domains, aux_modality, + self.ds = BrainDataset(mode, folder, image_size, 4, + debug=debug, + domains=domain, + num_channels=num_channels, + aux_modality=aux_modality) # mode, base_dir, domains: + + + + elif dataset.lower() == 'fsm_brain': + from dataset.BRATS_dataloader import Hybrid, ToTensor, RandomPadCrop, AddNoise + from torchvision import transforms + train_data_path = folder + + self.ds = Hybrid(split=mode, SNR=0, + # transform=transforms.Compose([RandomPadCrop(), ToTensor(), AddNoise()]), + transform=transforms.Compose([RandomPadCrop(), ToTensor()]), + base_dir=train_data_path, input_normalize=norm, + image_size=image_size, debug=debug) + + + self.test_ds = Hybrid(split="test", SNR=0, + # transform=transforms.Compose([RandomPadCrop(), ToTensor(), AddNoise()]), + transform=transforms.Compose([RandomPadCrop(), ToTensor()]), + base_dir=train_data_path, input_normalize=norm, + image_size=image_size, debug=debug) + + elif dataset.lower() == 'fsm_knee': + from dataset.fastmri import SliceDataset + root_path = folder + transforms = None + + # self.ds = build_dataset(args, mode='train') + self.ds = SliceDataset(root_path, transforms, 'singlecoil', + image_size=image_size, input_normalize=norm, + sample_rate=1, mode=mode, debug=debug) + + self.test_ds = SliceDataset(root_path, transforms, 'singlecoil', + image_size=image_size, input_normalize=norm, + sample_rate=1, mode="test", debug=debug) + + elif dataset.lower() == 'fsm_m4raw': + pass + + + + else: + print(dataset) + self.ds = Dataset(folder, image_size) + + self.train_batch_size = train_batch_size + self.batch_size = train_batch_size if mode == 'train' else 1 + + self.dl = cycle( + data.DataLoader(self.ds, + batch_size=train_batch_size if mode == 'train' else 1, + shuffle=(mode == "train"), + pin_memory=True, + num_workers=16, + drop_last=True)) + + self.test_dl = cycle( + data.DataLoader(self.ds, + batch_size=train_batch_size, + shuffle="test", + pin_memory=True, + num_workers=16, + drop_last=False)) + + self.opt = AdamW(list(self.model.module.restore_fn.parameters()), + lr=train_lr, + betas=(0.9, 0.999), + 
weight_decay=1e-4) + + self.scheduler = lr_scheduler.StepLR(self.opt, step_size=10000, gamma=0.5) + self.step = 0 + + # assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex must be installed for mixed precision training on' + + self.fp16 = fp16 + os.makedirs(results_folder, exist_ok=True) + self.results_folder = Path(results_folder) + + np.save(str(self.results_folder / "kspace_kernels.npy"), self.model.module.kspace_kernels.cpu()) + + self.lpips = LPIPS().eval().cuda() + + self.reset_parameters() + + if load_path is not None: + self.load(load_path) + kspace_npy = load_path.replace('model.pt', 'kspace_kernels.npy') + self.model.module.kspace_kernels = torch.from_numpy(np.load(kspace_npy)).to(self.model.module.kspace_kernels.device) + self.ema_model.module.kspace_kernels = self.model.module.kspace_kernels + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + def step_ema(self): + if self.step < self.step_start_ema: + self.reset_parameters() + return + self.ema.update_model_average(self.ema_model, self.model) + + def save(self): + model_data = { + 'step': self.step, + 'model': self.model.state_dict(), + 'ema': self.ema_model.state_dict() + } + save_name = str(self.results_folder / f'model.pt') + print("Save_name=", save_name) + torch.save(model_data, save_name) + + def load(self, load_path): + print("Loading : ", load_path) + model_data = torch.load(load_path) + + self.step = model_data['step'] + self.model.load_state_dict(model_data['model']) + self.ema_model.load_state_dict(model_data['ema']) + print("Loading complete") + + @staticmethod + def add_title(path, title): + img1 = cv2.imread(path) + + black = [0, 0, 0] + constant = cv2.copyMakeBorder(img1, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + height = 20 + violet = np.zeros((height, constant.shape[1], 3), np.uint8) + violet[:] = (255, 0, 180) + + vcat = cv2.vconcat((violet, constant)) + + font = cv2.FONT_HERSHEY_SIMPLEX + + cv2.putText(vcat, str(title), (violet.shape[1] // 2, height - 2), font, 0.5, (0, 0, 0), 1, 0) + cv2.imwrite(path, vcat) + + # + def calculate_metrics(self, all_images, og_img): + img_ = all_images.cpu() #.permute(0, 2, 3, 1).numpy()[..., 0] + og_img_ = og_img.cpu() #.permute(0, 2, 3, 1).numpy()[..., 0] + + # print("img_=", img_.shape, "og_img_=", og_img_.shape) # img_= torch.Size([4, 1, 240, 240]) og_img_= torch.Size([4, 1, 240, 240]) + + # B, C, H, W + # ssim = StructuralSimilarityIndexMeasure(data_range=255) + ssims_ = [] + for (img, og_img) in zip(img_, og_img_): + img_np = img.squeeze().numpy() # Convert to 2D + og_img_np = og_img.squeeze().numpy() # Convert to 2D + + # Compute SSIM for each pair of images + ssim_ = structural_similarity(og_img_np, img_np) + ssims_.append(ssim_) + + ssim_ = np.mean(ssims_) + + psnr_ = peak_signal_noise_ratio(og_img_.numpy(), img_.numpy()).mean() + nmse_ = nmse(og_img_.numpy(), img_.numpy()).mean() + + + return ssim_, psnr_, nmse_ + + # pip install pytorch-fid + + # (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + def calculate_metrics_3d(self, all_images, og_img): + all_images = torch.clamp(all_images, 1e-6, 1) + og_img = torch.clamp(og_img, 1e-6, 1) + + # img_ = torch.Size([5, 1, 64, 64]) torch.Size([5, 1, 64, 64] + all_images_new = all_images.unsqueeze(1).repeat(1, 3, 1, 1) + og_img_new = og_img.unsqueeze(1).repeat(1, 3, 1, 1) + + cal_fid = False + if cal_fid: + # (N, 3, 299, 299) + fid_value = calculate_fid(all_images_new.cpu().numpy(), og_img_new.cpu().numpy(), + use_multiprocessing=False, batch_size=og_img.shape[-1]) + # 
(N, 3, C, 256, 256) + fid_value_3d = calculate_fid_3d(all_images_new.cpu().numpy(), og_img_new.cpu().numpy(), + use_multiprocessing=False, batch_size=og_img.shape[-1]) + else: + fid_value = 0 + fid_value_3d = 0 + + # B, C, H, W + lpips = self.lpips(all_images_new.cuda(), og_img_new.cuda()).mean().item() + + # H, W, C + img_ = all_images.cpu().unsqueeze(0) # .permute(1, 2, 0).numpy() + og_img_ = og_img.cpu().unsqueeze(0) #.permute(1, 2, 0).numpy() #.numpy() + + # 0-1, H, W, C + ssim = StructuralSimilarityIndexMeasure(data_range=1.0) + ssim_ = ssim(og_img_, img_).mean() + psnr_ = psnr(og_img_.numpy(), img_.numpy(), data_range=1.0).mean() + + return ssim_, psnr_, fid_value, fid_value_3d, lpips + + # Evaluate all + def test_data_dict(self, tag, data_dict, batches, routine): + print("=== Test tag: ", tag) + + og_img = data_dict['img'].cuda() + aux = data_dict['aux'].cuda() + img_mean = data_dict['img_mean'].cuda() + img_std = data_dict['img_std'].cuda() + aux_mean = data_dict['aux_mean'].cuda() + aux_std = data_dict['aux_std'].cuda() + params_dict = {'img_mean': img_mean, 'img_std': img_std, 'aux_mean': aux_mean, 'aux_std': aux_std} + + + # print("input shape minmax, og_img:", og_img.min().item(), og_img.max().item(), + # "aux:", aux.min().item(), aux.max().item()) + + xt, direct_recons, all_images, return_k, all_recons, all_recons_fre, all_masks = ( + self.ema_model.module.sample( + batch_size=batches, + faded_recon_sample=og_img, + aux=aux, params_dict=params_dict, + sample_routine=routine)) + + img_std = img_std.view(-1, 1, 1, 1) + img_mean = img_mean.view(-1, 1, 1, 1) + aux_std = aux_std.view(-1, 1, 1, 1) + aux_mean = aux_mean.view(-1, 1, 1, 1) + + og_img = og_img * img_std + img_mean + + _min = og_img.min() + _max = og_img.max() + + og_img = (og_img - _min) / (_max - _min) + all_images = all_images * img_std + img_mean + all_images = (all_images - _min) / (_max - _min) + all_recons = all_recons * img_std + img_mean + all_recons = (all_recons - _min) / (_max - _min) + all_images = all_recons[-1] + + direct_recons = direct_recons * img_std + img_mean + direct_recons = (direct_recons - _min) / (_max - _min) + xt = xt * img_std + img_mean + xt = (xt - _min) / (_max - _min) + + aux = aux * aux_std + aux_mean + aux = (aux - aux.min()) / (aux.max() - aux.min()) + + # print("----------------------------------------\n" + # "all_recons:", all_recons.min().item(), all_recons.max().item(), + # "all_images:", all_images.min().item(), all_images.max().item(), + # "og_img:", og_img.min().item(), og_img.max().item(), + # "direct_recons:", direct_recons.min().item(), direct_recons.max().item()) + + all_recons = torch.clamp(all_recons, 1e-6, 1).mul(255).to(torch.int8) + direct_recons = torch.clamp(direct_recons, 1e-6, 1).mul(255).to(torch.int8) + all_images = torch.clamp(all_images, 1e-6, 1).mul(255).to(torch.int8) + og_img = torch.clamp(og_img, 1e-6, 1).mul(255).to(torch.int8) + + + + # 24, 1, 128, 128 + # Calculate SSIM and PSNR, LPIPS + ssims = [] + psnrs = [] + + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, direct_recons) + lpips = self.lpips(direct_recons.float()/255, og_img.float()/255).mean().item() + print(f"=== first step Metrics {routine}: SSIM: ", ssim_, " PSNR: ", psnr_, + " LPIPS: ", lpips, " NMSE: ", nmse_) + + for im in all_recons: + im = torch.clamp(im, 1e-6, 1).mul(255).to(torch.int8) + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, im) + ssims.append(ssim_) + psnrs.append(psnr_) + + + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, all_images) + lpips = 
self.lpips(all_images.float()/255, og_img.float()/255).mean().item() + + print(f"=== Final Metrics {routine}: SSIM: ", ssim_, " PSNR: ", psnr_, + " LPIPS: ", lpips, " NMSE: ", nmse_) + + os.makedirs(self.results_folder, exist_ok=True) + + # utils.save_image(xt, str(self.results_folder / f'{self.step}-xt-Noise.png'), nrow=6) + # utils.save_image(all_images, str(self.results_folder / f'{self.step}-full_recons.png'), + # nrow=6) + # utils.save_image(direct_recons, + # str(self.results_folder / f'{self.step}-sample-direct_recons.png'), nrow=6) + # utils.save_image(og_img, str(self.results_folder / f'{self.step}-img.png'), nrow=6) + # utils.save_image(aux, str(self.results_folder / f'{self.step}-aux.png'), nrow=6) + + # plot ssim and psnr in two sub plots same canvas parallel + import matplotlib.pyplot as plt + fig, axs = plt.subplots(2, figsize=(8, 6), dpi=100, sharex=True) + axs[0].plot(ssims, marker='o', linestyle='-', color='blue', label='SSIM') + axs[0].set_title('Structural Similarity Index (SSIM)', fontsize=14) + axs[0].set_ylabel('SSIM', fontsize=12) + axs[0].grid(True, linestyle='--', alpha=0.7) + axs[0].legend() + + # PSNR plot + axs[1].plot(psnrs, marker='o', linestyle='-', color='green', label='PSNR') + axs[1].set_title('Peak Signal-to-Noise Ratio (PSNR)', fontsize=14) + axs[1].set_xlabel('Iterations', fontsize=12) + axs[1].set_ylabel('PSNR (dB)', fontsize=12) + axs[1].grid(True, linestyle='--', alpha=0.7) + axs[1].legend() + + fig.tight_layout() + + plt.savefig(str(self.results_folder / f'{self.step}-metrics-{routine}.png')) + + return_k = return_k.cuda() + + + combine = torch.cat((return_k, + xt, + direct_recons, + all_images, + all_recons[-1].to(all_images.device), + og_img, aux), 2) + + utils.save_image(combine, str(self.results_folder / f'{self.step}-combine-{routine}.png'), nrow=6) + + # all_recon = all_recons[:, 0] # 50, 1, 128, 128 + # Ensure all_recons is on the CPU + + all_recons = torch.cat(list(all_recons), dim=-1).cpu() + all_masks = all_masks.cpu() + # all_masks = torch.cat(list(all_masks), dim=-1) + + s = all_recons.shape[-2] + repeats = all_recons.shape[3] // og_img.shape[3] # Calculate repeat factor + # tensor_small = tensor_small.repeat(1, 1, 1, repeats) + og_img = og_img.cpu() + all_recons_residual = all_recons - og_img.repeat(1, 1, 1, repeats) + all_recons_residual = (all_recons_residual - all_recons_residual.min()) / (all_recons_residual.max() - all_recons_residual.min()) + + # before and after residual + all_recons_residual_2 = all_recons[:, :, :, s:] - all_recons[:, :, :, :-s] + all_recons_residual_2 = (all_recons_residual_2 - all_recons_residual_2.min()) / (all_recons_residual_2.max() - all_recons_residual_2.min()) + padding = torch.zeros_like(all_recons_residual[:, :, :, :s // 2]) + all_recons_residual_2 = torch.cat([padding, all_recons_residual_2, padding], dim=-1) + + all_recons = torch.cat([all_recons, all_recons_residual_2, all_recons_residual], dim=-2) + + utils.save_image(all_recons, str(self.results_folder / f'{self.step}-all_recons-{routine}.png'), + nrow=1) + # utils.save_image(all_masks, str(self.results_folder / f'{self.step}-all_masks-{routine}.png'), + # nrow=1) + + # acc_loss = acc_loss / (self.save_and_sample_every + 1) + print(f'Mean of last {self.step}: save to :', str(self.results_folder / f'{self.step}-combine.png')) + + # acc_loss = 0 + + + def train(self): + backwards = partial(loss_backwards, self.fp16) + # writer = SummaryWriter() + + acc_loss = 0 + start_time = time.time() + + while self.step < self.train_num_steps: + d_time = 
time.time() + self.opt.zero_grad() + u_loss = 0 + + if (self.step + 1 )% 20000 == 0: + self.ds.update_chunk() + self.dl = cycle( + data.DataLoader(self.ds, + batch_size=self.train_batch_size, + shuffle=True, + pin_memory=True, + num_workers=16, + drop_last=True)) + self.test_dl = cycle( + data.DataLoader(self.test_ds, + batch_size=self.train_batch_size, + shuffle=False, + pin_memory=True, + num_workers=16, + drop_last=False)) + + self.debug_time = False + + for i in range(self.gradient_accumulate_every): + + last_model_state = self.model.state_dict() + optimizer_state = self.opt.state_dict() + + data_dict = next(self.dl) + if self.debug_time: + print("Data loading time=", time.time() - d_time) # Very slow, 0.5 s second on bask, 0.001 on local + + img = data_dict['img'].cuda() + aux = data_dict['aux'].cuda() + img_mean = data_dict['img_mean'].cuda() + img_std = data_dict['img_std'].cuda() + aux_mean = data_dict['aux_mean'].cuda() + aux_std = data_dict['aux_std'].cuda() + params_dict = {'img_mean': img_mean, 'img_std': img_std, 'aux_mean': aux_mean, 'aux_std': aux_std} + + + loss = torch.mean(self.model(img, aux, params_dict)) + if self.debug_time: + print("Model iter=", self.step, " time=", time.time() - d_time) + + if torch.isnan(loss).any(): + print(f"NaN encountered in step {self.step}. Reverting model.") + self.model.load_state_dict(last_model_state) # Revert model + self.opt.load_state_dict(optimizer_state) # Revert optimizer + continue # Skip the rest of this training step + if self.debug_time: + print("before loss=", time.time() - d_time) + u_loss += loss + loss.backward() + # backwards(loss / self.gradient_accumulate_every, self.opt) + if self.debug_time: + print("after loss=", time.time() - d_time) + + del img, aux, img_mean, img_std, aux_mean, aux_std, params_dict + + + + if (self.step + 1) % (min(self.train_num_steps // 100 + 1, 100)) == 0: + print('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % ( + self.step + 1, time.time() - start_time, self.scheduler.get_lr()[0], loss.item())) + + # writer.add_scalar("Loss/train", loss.item(), self.step) + acc_loss = acc_loss + (u_loss.item() / self.gradient_accumulate_every) + + max_norm = 0.01 # Maximum norm for gradients + torch.nn.utils.clip_grad_norm_(self.model.module.restore_fn.parameters(), max_norm) + if self.debug_time: + print("Before Optimization time=", time.time() - d_time) + + if self.fp16: + scaler.step(self.opt) + self.scheduler.step() + scaler.update() + + else: + self.opt.step() + self.scheduler.step() + + if self.debug_time: + print("Optimization time=", time.time() - d_time) + + if self.step % self.update_ema_every == 0: + self.step_ema() + + + # TEST and SAVE + if self.step != 0 and (self.step + 1) % self.save_and_sample_every == 0: + batches = self.batch_size + data_dict = next(self.test_dl) # .cuda() + train_dict = next(self.dl) # .cuda() + + # 'default', "fre_progressive", "x0_step_down" + for routine in ['x0_step_down_fre']: + self.test_data_dict("Train", train_dict, batches, routine) + self.test_data_dict("Test", data_dict, batches, routine) + + self.ema_model.module.restore_fn.train() + self.save() + + self.step += 1 + clean_start = time.time() + + del data_dict + del u_loss + torch.cuda.empty_cache() + gc.collect() + if self.debug_time: + print("Clean time=", time.time() - clean_start) + + print("Iter time = ", time.time() - d_time, "total time = ", time.time() - start_time) + + print('training completed') + + def test_loader(self, sampling_routine): + """ + Computes patient-wise 3D for a dataset of MRI 
slices. + + """ + print("Starting testing with sampling routine: ", sampling_routine) + + model = self.model # self.ema_model # or self.model + model.eval() # Set model to evaluation mode + + # sampling_routine = ['default', 'x0_step_down', 'x0_step_down_fre', "fre_progressive"]: + num_timesteps = model.module.num_timesteps + patient_wise_pred = {} # num_timesteps + patient_wise_gt = [] + count = 1 + + while True: + batches = 1 #self.batch_size + data_dict = next(self.dl) # .cuda() + + + # Prediction + og_img = data_dict['img'].cuda() + aux = data_dict['aux'].cuda() + img_mean = data_dict['img_mean'].cuda() + img_std = data_dict['img_std'].cuda() + aux_mean = data_dict['aux_mean'].cuda() + aux_std = data_dict['aux_std'].cuda() + params_dict = {'img_mean': img_mean, 'img_std': img_std, 'aux_mean': aux_mean, 'aux_std': aux_std} + + xt, direct_recons, all_images, return_k, all_recons, all_recons_fre, all_masks = ( + model.module.sample( + batch_size=batches, + faded_recon_sample=og_img, + aux=aux, params_dict=params_dict, + sample_routine=sampling_routine)) + + img_std = img_std.view(-1, 1, 1, 1) + img_mean = img_mean.view(-1, 1, 1, 1) + aux_std = aux_std.view(-1, 1, 1, 1) + aux_mean = aux_mean.view(-1, 1, 1, 1) + + + print("before or_img shape: ", og_img.shape, og_img.min(), og_img.max()) + print("before direct_recons shape: ", direct_recons.shape, direct_recons.min(), direct_recons.max()) + # into a Normalized Image + direct_recons_norm = (direct_recons - direct_recons.mean()) / (direct_recons.std()) + all_images_norm = (all_images - all_images.mean()) / (all_images.std()) + + og_img = og_img * img_std + img_mean + _min = og_img.min() + _max = og_img.max() + + # og_img = (og_img - _min) / (_max - _min) # 0 - 1 + all_images = all_images * img_std + img_mean + # all_images = (all_images - _min) / (_max - _min) + all_images_norm = all_images_norm * img_std + img_mean + # all_images_norm = (all_images_norm - _min) / (_max - _min) + + all_recons = all_recons * img_std + img_mean + # all_recons = (all_recons - _min) / (_max - _min) + direct_recons = direct_recons * img_std + img_mean + # direct_recons = (direct_recons - _min) / (_max - _min) + + direct_recons_norm = direct_recons_norm * img_std + img_mean + # direct_recons_norm = (direct_recons_norm - _min) / (_max - _min) + + xt = xt * img_std + img_mean + # xt = (xt - _min) / (_max - _min) + + aux = aux * aux_std + aux_mean + # aux = (aux - aux.min()) / (aux.max() - aux.min()) + all_recons = all_recons.cpu() + + all_recons = torch.clamp(all_recons, 1e-6, 1).mul(255).to(torch.int8) + direct_recons = torch.clamp(direct_recons, 1e-6, 1).mul(255).to(torch.int8) + all_images = torch.clamp(all_images, 1e-6, 1).mul(255).to(torch.int8) + og_img = torch.clamp(og_img, 1e-6, 1).mul(255).to(torch.int8) + + # 24, 1, 128, 128 + # Calculate SSIM and PSNR, LPIPS + ssims = [] + psnrs = [] + + + + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, direct_recons) + lpips = self.lpips(direct_recons.float(), og_img.float()).mean().item() + print(f"=== first step Metrics {sampling_routine}: SSIM: ", ssim_, " PSNR: ", psnr_, " LPIPS: ", lpips, " NMSE: ", nmse_) + + for im in all_recons: + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, im) + ssims.append(ssim_) + psnrs.append(psnr_) + + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, all_images) + lpips = self.lpips(all_images.float(), og_img.float()).mean().item() + + print(f"=== Final Metrics {sampling_routine}: SSIM: ", ssim_, " PSNR: ", psnr_, " LPIPS: ", lpips, " NMSE: ", nmse_) + + + 
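+ # NOTE: nmse() above is assumed to be the standard normalized MSE,
+ #   NMSE(gt, pred) = ||gt - pred||_2^2 / ||gt||_2^2.
+ # Also, torch.int8 is signed (-128..127), so the .mul(255).to(torch.int8) conversions
+ # above can wrap pixels brighter than ~0.5; torch.uint8 would preserve the full 0-255 range.
+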
os.makedirs(self.results_folder, exist_ok=True) + + # utils.save_image(xt, str(self.results_folder / f'{self.step}-xt-Noise.png'), nrow=6) + # utils.save_image(all_images, str(self.results_folder / f'{self.step}-full_recons.png'), + # nrow=6) + # utils.save_image(direct_recons, + # str(self.results_folder / f'{self.step}-sample-direct_recons.png'), nrow=6) + # utils.save_image(og_img, str(self.results_folder / f'{self.step}-img.png'), nrow=6) + # utils.save_image(aux, str(self.results_folder / f'{self.step}-aux.png'), nrow=6) + + # plot ssim and psnr in two sub plots same canvas parallel + import matplotlib.pyplot as plt + fig, axs = plt.subplots(2, figsize=(8, 6), dpi=100, sharex=True) + axs[0].plot(ssims, marker='o', linestyle='-', color='blue', label='SSIM') + axs[0].set_title('Structural Similarity Index (SSIM)', fontsize=14) + axs[0].set_ylabel('SSIM', fontsize=12) + axs[0].grid(True, linestyle='--', alpha=0.7) + axs[0].legend() + + # PSNR plot + axs[1].plot(psnrs, marker='o', linestyle='-', color='green', label='PSNR') + axs[1].set_title('Peak Signal-to-Noise Ratio (PSNR)', fontsize=14) + axs[1].set_xlabel('Iterations', fontsize=12) + axs[1].set_ylabel('PSNR (dB)', fontsize=12) + axs[1].grid(True, linestyle='--', alpha=0.7) + axs[1].legend() + + fig.tight_layout() + + plt.savefig(str(self.results_folder / f'{count}-metrics-{sampling_routine}.png')) + + return_k = return_k.cuda() + + combine = torch.cat((return_k, + xt, + all_images, direct_recons, og_img, aux), 2) + + # utils.save_image(combine, str(self.results_folder / f'{self.step}-combine-{routine}.png'), nrow=6) + + # all_recon = all_recons[:, 0] # 50, 1, 128, 128 + # Ensure all_recons is on the CPU + + all_recons = torch.cat(list(all_recons), dim=-1) + all_masks = all_masks.cpu() + all_masks = torch.cat(list(all_masks), dim=-1) + + s = all_recons.shape[-2] + repeats = all_recons.shape[3] // og_img.shape[3] # Calculate repeat factor + # tensor_small = tensor_small.repeat(1, 1, 1, repeats) + og_img = og_img.cpu() + all_recons_residual = all_recons - og_img.repeat(1, 1, 1, repeats) + # all_recons[:, :, :, s:] + all_recons_residual_2 = all_recons[:, :, :, s:] - all_recons[:, :, :, :-s] + padding = torch.zeros_like(all_recons_residual[:, :, :, :s // 2]) + all_recons_residual_2 = torch.cat([padding, all_recons_residual_2, padding], dim=-1) + + all_recons = torch.cat([all_recons, all_recons_residual_2, all_recons_residual], dim=-2) + + # utils.save_image(all_recons, str(self.results_folder / f'{self.step}-all_recons-{routine}.png'), + # nrow=1) + count += 1 + + + def test_loader_3d(self, sampling_routine): + """ + Computes patient-wise 3D for a dataset of MRI slices. 
+ + """ + print("Starting testing with sampling routine: ", sampling_routine) + + model = self.model # self.ema_model # or self.model + model.eval() # Set model to evaluation mode + + # sampling_routine = ['default', 'x0_step_down', 'x0_step_down_fre', "fre_progressive"]: + num_timesteps = model.module.num_timesteps + patient_wise_pred = {} # num_timesteps + patient_wise_gt = [] + + while True: + batches = 1 #self.batch_size + data_dict = next(self.dl) # .cuda() + if data_dict['is_start']: + patient_wise_pred = {} # num_timesteps of list + for i in range(num_timesteps): + patient_wise_pred[i] = [] + patient_wise_gt = [] + + # Original Input + og_img = data_dict['img'].cuda() + aux = data_dict['aux'].cuda() + img_mean = data_dict['img_mean'].cuda() + img_std = data_dict['img_std'].cuda() + aux_mean = data_dict['aux_mean'].cuda() + aux_std = data_dict['aux_std'].cuda() + params_dict = {'img_mean': img_mean, 'img_std': img_std, + 'aux_mean': aux_mean, 'aux_std': aux_std} + + # Prediction + xt, direct_recons, all_images, return_k, all_recons, all_recons_fre, all_masks = ( + model.module.sample( + batch_size=batches, + faded_recon_sample=og_img, + aux=aux, params_dict=params_dict, + sample_routine=sampling_routine)) + + # Handling the output + img_std = img_std.view(-1, 1, 1, 1) + img_mean = img_mean.view(-1, 1, 1, 1) + aux_std = aux_std.view(-1, 1, 1, 1) + aux_mean = aux_mean.view(-1, 1, 1, 1) + + og_img = og_img * img_std + img_mean + _min = og_img.min() + _max = og_img.max() + + og_img = (og_img - _min) / (_max - _min) + all_images = all_images * img_std + img_mean + all_images = (all_images - _min) / (_max - _min) + + all_recons = all_recons * img_std + img_mean + all_recons = (all_recons - _min) / (_max - _min) + direct_recons = direct_recons * img_std + img_mean + direct_recons = (direct_recons - _min) / (_max - _min) + xt = xt * img_std + img_mean + xt = (xt - _min) / (_max - _min) + + aux = aux * aux_std + aux_mean + aux = (aux - aux.min()) / (aux.max() - aux.min()) + all_recons = all_recons.cpu() + + # Save to list + patient_wise_gt.append(og_img) + for i in range(num_timesteps): + patient_wise_pred[i].append(all_recons[i]) + + if data_dict['is_end']: # or len(patient_wise_gt) >=10: # TODO + # 24, 1, 128, 128 + # Calculate SSIM and PSNR, LPIPS + patient_gt = torch.cat(patient_wise_gt, dim=0).squeeze() # C, H, W + + ssims, psnrs, lpips, fid, fid_3d = [], [], [], [], [] + + for i in range(num_timesteps): + patient_pred = torch.cat(patient_wise_pred[i], dim=0).squeeze() + + ssim_, psnr_, fid_, fid3d_, lpips_ = self.calculate_metrics_3d(patient_gt, patient_pred) + print("time step: ", i, "SSIM: ", ssim_, " PSNR: ", psnr_, " LPIPS: ", lpips_, " FID: ", fid_, " FID_3D: ", fid3d_) + + lpips.append(lpips_) + fid.append(fid_) + fid_3d.append(fid3d_) + ssims.append(ssim_) + psnrs.append(psnr_) + + print(f"=== Final Metrics {sampling_routine}: SSIM: ", ssim_, " PSNR: ", psnr_, " LPIPS: ", lpips_) + + + file_id = data_dict['file_id'][0].split("/")[-1] + + os.makedirs(self.results_folder, exist_ok=True) + + import matplotlib.pyplot as plt + fig, axs = plt.subplots(5, figsize=(8, 6), dpi=100, sharex=True) + axs[0].plot(ssims, marker='o', linestyle='-', color='blue', label='SSIM') + axs[0].set_title('Structural Similarity Index (SSIM)', fontsize=14) + axs[0].set_ylabel('SSIM', fontsize=12) + axs[0].grid(True, linestyle='--', alpha=0.7) + axs[0].legend() + + # PSNR plot + axs[1].plot(psnrs, marker='o', linestyle='-', color='green', label='PSNR') + axs[1].set_title('Peak Signal-to-Noise Ratio (PSNR)', 
fontsize=14) + axs[1].set_xlabel('Iterations', fontsize=12) + axs[1].set_ylabel('PSNR (dB)', fontsize=12) + axs[1].grid(True, linestyle='--', alpha=0.7) + axs[1].legend() + + # LPIPS plot + axs[2].plot(lpips, marker='o', linestyle='-', color='red', label='LPIPS') + axs[2].set_title('LPIPS', fontsize=14) + axs[2].set_xlabel('Iterations', fontsize=12) + axs[2].set_ylabel('LPIPS', fontsize=12) + axs[2].grid(True, linestyle='--', alpha=0.7) + axs[2].legend() + + # FID plot + axs[3].plot(fid, marker='o', linestyle='-', color='orange', label='FID') + axs[3].set_title('FID', fontsize=14) + axs[3].set_xlabel('Iterations', fontsize=12) + axs[3].set_ylabel('FID', fontsize=12) + axs[3].grid(True, linestyle='--', alpha=0.7) + axs[3].legend() + + axs[4].plot(fid_3d, marker='o', linestyle='-', color='purple', label='FID_3D') + axs[4].set_title('FID_3D', fontsize=14) + axs[4].set_xlabel('Iterations', fontsize=12) + axs[4].set_ylabel('FID_3D', fontsize=12) + axs[4].grid(True, linestyle='--', alpha=0.7) + axs[4].legend() + + fig.tight_layout() + save = str(self.results_folder / f'{file_id}-metrics-{sampling_routine}.png') + plt.savefig(save) + + print("Save metrics to ", save) + + + + + def test_from_data(self, extra_path, s_times=None): + batches = self.batch_size + og_img = next(self.dl).cuda() + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, faded_recon_sample=og_img, times=s_times) + + og_img = (og_img + 1) * 0.5 + utils.save_image(og_img, str(self.results_folder / f'og-{extra_path}.png'), nrow=6) + + frames_t = [] + frames_0 = [] + + for i in range(len(x0_list)): + print(i) + + x_0 = x0_list[i] + x_0 = (x_0 + 1) * 0.5 + utils.save_image(x_0, str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), str(i)) + frames_0.append(imageio.imread(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'))) + + x_t = xt_list[i] + all_images = (x_t + 1) * 0.5 + utils.save_image(all_images, str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), str(i)) + frames_t.append(imageio.imread(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'))) + + imageio.mimsave(str(self.results_folder / f'Gif-{extra_path}-x0.gif'), frames_0) + imageio.mimsave(str(self.results_folder / f'Gif-{extra_path}-xt.gif'), frames_t) + + def test_with_mixup(self, extra_path): + batches = self.batch_size + og_img_1 = next(self.dl).cuda() + og_img_2 = next(self.dl).cuda() + og_img = (og_img_1 + og_img_2) / 2 + + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, faded_recon_sample=og_img) + + og_img_1 = (og_img_1 + 1) * 0.5 + utils.save_image(og_img_1, str(self.results_folder / f'og1-{extra_path}.png'), nrow=6) + + og_img_2 = (og_img_2 + 1) * 0.5 + utils.save_image(og_img_2, str(self.results_folder / f'og2-{extra_path}.png'), nrow=6) + + og_img = (og_img + 1) * 0.5 + utils.save_image(og_img, str(self.results_folder / f'og-{extra_path}.png'), nrow=6) + + frames_t = [] + frames_0 = [] + + for i in range(len(x0_list)): + print(i) + x_0 = x0_list[i] + x_0 = (x_0 + 1) * 0.5 + utils.save_image(x_0, str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), str(i)) + frames_0.append(Image.open(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'))) + + x_t = xt_list[i] + all_images = (x_t + 1) * 0.5 + 
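+ # (x_t + 1) * 0.5 maps the model's [-1, 1] output back to [0, 1], which is the range
+ # torchvision.utils.save_image expects for float tensors.
+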
utils.save_image(all_images, str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), str(i)) + frames_t.append(Image.open(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'))) + + frame_one = frames_0[0] + frame_one.save(str(self.results_folder / f'Gif-{extra_path}-x0.gif'), format="GIF", append_images=frames_0, + save_all=True, duration=100, loop=0) + + frame_one = frames_t[0] + frame_one.save(str(self.results_folder / f'Gif-{extra_path}-xt.gif'), format="GIF", append_images=frames_t, + save_all=True, duration=100, loop=0) + + def test_from_random(self, extra_path): + batches = self.batch_size + og_img = next(self.dl).cuda() + og_img = og_img * 0.9 + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, faded_recon_sample=og_img) + + og_img = (og_img + 1) * 0.5 + utils.save_image(og_img, str(self.results_folder / f'og-{extra_path}.png'), nrow=6) + + frames_t_names = [] + frames_0_names = [] + + for i in range(len(x0_list)): + print(i) + + x_0 = x0_list[i] + x_0 = (x_0 + 1) * 0.5 + utils.save_image(x_0, str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), str(i)) + frames_0_names.append(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png')) + + x_t = xt_list[i] + all_images = (x_t + 1) * 0.5 + utils.save_image(all_images, str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), str(i)) + frames_t_names.append(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png')) + + frames_0 = [] + frames_t = [] + for i in range(len(x0_list)): + print(i) + frames_0.append(imageio.imread(frames_0_names[i])) + frames_t.append(imageio.imread(frames_t_names[i])) + + imageio.mimsave(str(self.results_folder / f'Gif-{extra_path}-x0.gif'), frames_0) + imageio.mimsave(str(self.results_folder / f'Gif-{extra_path}-xt.gif'), frames_t) + + def controlled_direct_reconstruct(self, extra_path): + batches = self.batch_size + torch.manual_seed(0) + og_img = next(self.dl).cuda() + xt, direct_recons, all_images = self.ema_model.module.sample(batch_size=batches, faded_recon_sample=og_img) + + og_img = (og_img + 1) * 0.5 + utils.save_image(og_img, str(self.results_folder / f'sample-og-{extra_path}.png'), nrow=6) + + all_images = (all_images + 1) * 0.5 + utils.save_image(all_images, str(self.results_folder / f'sample-recon-{extra_path}.png'), nrow=6) + + direct_recons = (direct_recons + 1) * 0.5 + utils.save_image(direct_recons, str(self.results_folder / f'sample-direct_recons-{extra_path}.png'), nrow=6) + + xt = (xt + 1) * 0.5 + utils.save_image(xt, str(self.results_folder / f'sample-xt-{extra_path}.png'), + nrow=6) + + self.save() + + def fid_distance_decrease_from_manifold(self, fid_func, start=0, end=1000): + + all_samples = [] + dataset = self.ds + + print(len(dataset)) + for idx in range(len(dataset)): + img = dataset[idx] + img = torch.unsqueeze(img, 0).cuda() + if idx > start: + all_samples.append(img[0]) + if idx % 1000 == 0: + print(idx) + if end is not None: + if idx == end: + print(idx) + break + + all_samples = torch.stack(all_samples) + blurred_samples = None + original_sample = None + deblurred_samples = None + direct_deblurred_samples = None + + sanity_check = blurred_samples + + cnt = 0 + while cnt < all_samples.shape[0]: + og_x = all_samples[cnt: cnt + 50] + og_x = og_x.cuda() + og_x = 
og_x.type(torch.cuda.FloatTensor) + og_img = og_x + print(og_img.shape) + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=og_img.shape[0], + faded_recon_sample=og_img, + times=None) + + og_img = og_img.to('cpu') + blurry_imgs = xt_list[0].to('cpu') + deblurry_imgs = x0_list[-1].to('cpu') + direct_deblurry_imgs = x0_list[0].to('cpu') + + og_img = og_img.repeat(1, 3 // og_img.shape[1], 1, 1) + blurry_imgs = blurry_imgs.repeat(1, 3 // blurry_imgs.shape[1], 1, 1) + deblurry_imgs = deblurry_imgs.repeat(1, 3 // deblurry_imgs.shape[1], 1, 1) + direct_deblurry_imgs = direct_deblurry_imgs.repeat(1, 3 // direct_deblurry_imgs.shape[1], 1, 1) + + og_img = (og_img + 1) * 0.5 + blurry_imgs = (blurry_imgs + 1) * 0.5 + deblurry_imgs = (deblurry_imgs + 1) * 0.5 + direct_deblurry_imgs = (direct_deblurry_imgs + 1) * 0.5 + + if cnt == 0: + print(og_img.shape) + print(blurry_imgs.shape) + print(deblurry_imgs.shape) + print(direct_deblurry_imgs.shape) + + if sanity_check: + folder = './sanity_check/' + create_folder(folder) + + san_imgs = og_img[0: 32] + utils.save_image(san_imgs, str(folder + f'sample-og.png'), nrow=6) + + san_imgs = blurry_imgs[0: 32] + utils.save_image(san_imgs, str(folder + f'sample-xt.png'), nrow=6) + + san_imgs = deblurry_imgs[0: 32] + utils.save_image(san_imgs, str(folder + f'sample-recons.png'), nrow=6) + + san_imgs = direct_deblurry_imgs[0: 32] + utils.save_image(san_imgs, str(folder + f'sample-direct-recons.png'), nrow=6) + + if blurred_samples is None: + blurred_samples = blurry_imgs + else: + blurred_samples = torch.cat((blurred_samples, blurry_imgs), dim=0) + + if original_sample is None: + original_sample = og_img + else: + original_sample = torch.cat((original_sample, og_img), dim=0) + + if deblurred_samples is None: + deblurred_samples = deblurry_imgs + else: + deblurred_samples = torch.cat((deblurred_samples, deblurry_imgs), dim=0) + + if direct_deblurred_samples is None: + direct_deblurred_samples = direct_deblurry_imgs + else: + direct_deblurred_samples = torch.cat((direct_deblurred_samples, direct_deblurry_imgs), dim=0) + + cnt += og_img.shape[0] + + print(blurred_samples.shape) + print(original_sample.shape) + print(deblurred_samples.shape) + print(direct_deblurred_samples.shape) + + fid_blur = fid_func(samples=[original_sample, blurred_samples]) + rmse_blur = torch.sqrt(torch.mean((original_sample - blurred_samples) ** 2)) + ssim_blur = ssim(original_sample, blurred_samples, data_range=1, size_average=True) + print(f'The FID of blurry images with original image is {fid_blur}') + print(f'The RMSE of blurry images with original image is {rmse_blur}') + print(f'The SSIM of blurry images with original image is {ssim_blur}') + + fid_deblur = fid_func(samples=[original_sample, deblurred_samples]) + rmse_deblur = torch.sqrt(torch.mean((original_sample - deblurred_samples) ** 2)) + ssim_deblur = ssim(original_sample, deblurred_samples, data_range=1, size_average=True) + print(f'The FID of deblurred images with original image is {fid_deblur}') + print(f'The RMSE of deblurred images with original image is {rmse_deblur}') + print(f'The SSIM of deblurred images with original image is {ssim_deblur}') + + print(f'Hence the improvement in FID using sampling is {fid_blur - fid_deblur}') + + fid_direct_deblur = fid_func(samples=[original_sample, direct_deblurred_samples]) + rmse_direct_deblur = torch.sqrt(torch.mean((original_sample - direct_deblurred_samples) ** 2)) + ssim_direct_deblur = ssim(original_sample, direct_deblurred_samples, data_range=1, size_average=True) + 
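+ # For images in [0, 1] (data_range=1), these RMSE values relate to the PSNR reported
+ # elsewhere in this file via PSNR = 20 * log10(1 / RMSE); e.g. RMSE = 0.01 corresponds to 40 dB.
+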
print(f'The FID of direct deblurred images with original image is {fid_direct_deblur}') + print(f'The RMSE of direct deblurred images with original image is {rmse_direct_deblur}') + print(f'The SSIM of direct deblurred images with original image is {ssim_direct_deblur}') + + print(f'Hence the improvement in FID using direct sampling is {fid_blur - fid_direct_deblur}') + + def paper_invert_section_images(self, s_times=None): + + cnt = 0 + for i in range(50): + batches = self.batch_size + og_img = next(self.dl).cuda() + print(og_img.shape) + + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, + faded_recon_sample=og_img, + times=s_times) + og_img = (og_img + 1) * 0.5 + + for j in range(og_img.shape[0]//3): + original = og_img[j: j + 1] + utils.save_image(original, str(self.results_folder / f'original_{cnt}.png'), nrow=3) + + direct_recons = x0_list[0][j: j + 1] + direct_recons = (direct_recons + 1) * 0.5 + utils.save_image(direct_recons, str(self.results_folder / f'direct_recons_{cnt}.png'), nrow=3) + + sampling_recons = x0_list[-1][j: j + 1] + sampling_recons = (sampling_recons + 1) * 0.5 + utils.save_image(sampling_recons, str(self.results_folder / f'sampling_recons_{cnt}.png'), nrow=3) + + blurry_image = xt_list[0][j: j + 1] + blurry_image = (blurry_image + 1) * 0.5 + utils.save_image(blurry_image, str(self.results_folder / f'blurry_image_{cnt}.png'), nrow=3) + + blurry_image = cv2.imread(f'{self.results_folder}/blurry_image_{cnt}.png') + direct_recons = cv2.imread(f'{self.results_folder}/direct_recons_{cnt}.png') + sampling_recons = cv2.imread(f'{self.results_folder}/sampling_recons_{cnt}.png') + original = cv2.imread(f'{self.results_folder}/original_{cnt}.png') + + black = [0, 0, 0] + blurry_image = cv2.copyMakeBorder(blurry_image, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + direct_recons = cv2.copyMakeBorder(direct_recons, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + sampling_recons = cv2.copyMakeBorder(sampling_recons, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + original = cv2.copyMakeBorder(original, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + + im_h = cv2.hconcat([blurry_image, direct_recons, sampling_recons, original]) + cv2.imwrite(f'{self.results_folder}/all_{cnt}.png', im_h) + + cnt += 1 + + def paper_showing_diffusion_images(self, s_times=None): + + cnt = 0 + to_show = [0, 1, 2, 4, 8, 16, 32, 64, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100] + + for i in range(100): + batches = self.batch_size + og_img = next(self.dl).cuda() + print(og_img.shape) + + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, faded_recon_sample=og_img, times=s_times) + + for k in range(xt_list[0].shape[0]): + lst = [] + + for j in range(len(xt_list)): + x_t = xt_list[j][k] + x_t = (x_t + 1) * 0.5 + utils.save_image(x_t, str(self.results_folder / f'x_{len(xt_list)-j}_{cnt}.png'), nrow=1) + x_t = cv2.imread(f'{self.results_folder}/x_{len(xt_list)-j}_{cnt}.png') + if j in to_show: + lst.append(x_t) + + x_0 = x0_list[-1][k] + x_0 = (x_0 + 1) * 0.5 + utils.save_image(x_0, str(self.results_folder / f'x_best_{cnt}.png'), nrow=1) + x_0 = cv2.imread(f'{self.results_folder}/x_best_{cnt}.png') + lst.append(x_0) + im_h = cv2.hconcat(lst) + cv2.imwrite(f'{self.results_folder}/all_{cnt}.png', im_h) + cnt += 1 + + def test_from_data_save_results(self): + batch_size = 100 + dl = data.DataLoader(self.ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=16, + drop_last=True) + + all_samples = None + + for i, img in enumerate(dl, 
0): + print(i) + print(img.shape) + if all_samples is None: + all_samples = img + else: + all_samples = torch.cat((all_samples, img), dim=0) + + # break + + # create_folder(f'{self.results_folder}/') + blurred_samples = None + original_sample = None + deblurred_samples = None + direct_deblurred_samples = None + + sanity_check = 1 + + orig_folder = f'{self.results_folder}_orig/' + create_folder(orig_folder) + + blur_folder = f'{self.results_folder}_blur/' + create_folder(blur_folder) + + d_deblur_folder = f'{self.results_folder}_d_deblur/' + create_folder(d_deblur_folder) + + deblur_folder = f'{self.results_folder}_deblur/' + create_folder(deblur_folder) + + cnt = 0 + while cnt < all_samples.shape[0]: + print(cnt) + og_x = all_samples[cnt: cnt + 32] + og_x = og_x.cuda() + og_x = og_x.type(torch.cuda.FloatTensor) + og_img = og_x + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=og_img.shape[0], faded_recon_sample=og_img, times=None) + + og_img = og_img.to('cpu') + blurry_imgs = xt_list[0].to('cpu') + deblurry_imgs = x0_list[-1].to('cpu') + direct_deblurry_imgs = x0_list[0].to('cpu') + + og_img = og_img.repeat(1, 3 // og_img.shape[1], 1, 1) + blurry_imgs = blurry_imgs.repeat(1, 3 // blurry_imgs.shape[1], 1, 1) + deblurry_imgs = deblurry_imgs.repeat(1, 3 // deblurry_imgs.shape[1], 1, 1) + direct_deblurry_imgs = direct_deblurry_imgs.repeat(1, 3 // direct_deblurry_imgs.shape[1], 1, 1) + + og_img = (og_img + 1) * 0.5 + blurry_imgs = (blurry_imgs + 1) * 0.5 + deblurry_imgs = (deblurry_imgs + 1) * 0.5 + direct_deblurry_imgs = (direct_deblurry_imgs + 1) * 0.5 + + if cnt == 0: + print(og_img.shape) + print(blurry_imgs.shape) + print(deblurry_imgs.shape) + print(direct_deblurry_imgs.shape) + + if blurred_samples is None: + blurred_samples = blurry_imgs + else: + blurred_samples = torch.cat((blurred_samples, blurry_imgs), dim=0) + + if original_sample is None: + original_sample = og_img + else: + original_sample = torch.cat((original_sample, og_img), dim=0) + + if deblurred_samples is None: + deblurred_samples = deblurry_imgs + else: + deblurred_samples = torch.cat((deblurred_samples, deblurry_imgs), dim=0) + + if direct_deblurred_samples is None: + direct_deblurred_samples = direct_deblurry_imgs + else: + direct_deblurred_samples = torch.cat((direct_deblurred_samples, direct_deblurry_imgs), dim=0) + + cnt += og_img.shape[0] + + print(blurred_samples.shape) + print(original_sample.shape) + print(deblurred_samples.shape) + print(direct_deblurred_samples.shape) + + for i in range(blurred_samples.shape[0]): + utils.save_image(original_sample[i], f'{orig_folder}{i}.png', nrow=1) + utils.save_image(blurred_samples[i], f'{blur_folder}{i}.png', nrow=1) + utils.save_image(deblurred_samples[i], f'{deblur_folder}{i}.png', nrow=1) + utils.save_image(direct_deblurred_samples[i], f'{d_deblur_folder}{i}.png', nrow=1) diff --git a/MRI_recon/code/Frequency-Diffusion/draw/frequency_sampling.py b/MRI_recon/code/Frequency-Diffusion/draw/frequency_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..30633cf87161ffff68129e3b823d5f1b0d04f2c2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/draw/frequency_sampling.py @@ -0,0 +1,284 @@ +import torch + +from utils.k_degrade_utils import * + + +if __name__ == "__main__": + # First STEP + import matplotlib.pyplot as plt + import numpy as np, os + + os.makedirs("outputs", exist_ok=True) + + os.makedirs("outputs/low-fre-first", exist_ok=True) + os.makedirs("outputs/random-sample", exist_ok=True) + + + image_size = 256 + 
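+ # Undersampling setup: with acceleration factor R, roughly 1/R of the k-space columns are
+ # kept (about 256 / 6 ~ 43 of the 256 phase-encode lines here), and center_fraction is the
+ # share of low-frequency center columns that remain fully sampled.
+ # The degradation itself (apply_ksu_kernel in draw/utils/k_degrade_utils.py) multiplies the
+ # centered 2-D FFT of the image by a binary mask and transforms back, roughly:
+ #   k = fftshift(fft2(img), dim=(-2, -1))
+ #   img_us = ifft2(ifftshift(k * mask, dim=(-2, -1))).real
+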
accelerated_factor = 6 + center_fraction = 0.04 + time_step = 25 + + + masks = get_ksu_kernel(time_step, image_size, "LogSamplingRate", + accelerated_factor=accelerated_factor, center_fraction=center_fraction) # LogSamplingRate + + + batch_size = 1 + + img = plt.imread("./assets/BraTS20_Training_001_86_t1.png") + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + + print("input img shape: ", img.shape) + + # to gray scale + if len(img.shape) == 3 and img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + + # img = np.transpose(img, (2, 0, 1)) + # img = img[0] + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + original_img = img.clone() + + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + img = img #* 2 - 1 # + + masked_img = [] + + for m in masks: + m = m.unsqueeze(0) + img = apply_ksu_kernel(img, m) + masked_img.append(img) + + save_masks = masks + masks = np.concatenate(masks, axis=-1)[0] + masked_img = torch.concat(masked_img, dim=-1).numpy() #+ 1) * 0.5 + + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + + + img = np.concatenate([masks, masked_img], axis=0) + min_ = masked_img.min() + max_ = masked_img.max() + + out = img[image_size: 2 * image_size, : image_size] + fft, _ = apply_tofre(torch.from_numpy(out), torch.from_numpy(out)) # complex + fft = np.abs(fft.numpy()) + fft = np.log(fft) + fft = (fft - fft.min()) / (fft.max() - fft.min()) + + for i in range(time_step+1): + out = img[image_size: 2 * image_size, i * image_size: (i + 1) * image_size] + + # out = (out - out.min()) / (out.max() - out.min()) + out = (out - min_) / (max_ - min_) + plt.imsave(f"outputs/low-fre-first/{i}_image.png", out, cmap='gray') + + if i != 0: + out = img[:image_size, i * image_size:(i + 1) * image_size] + out = (out - out.min()) / (out.max() - out.min()) + + plt.imsave(f"outputs/low-fre-first/{i}_mask.png", out, cmap='gray') + + save_fft = fft * out + plt.imsave(f"outputs/low-fre-first/{i}_fft.png", save_fft, cmap='gray') + + + else: + diff = np.ones((image_size, image_size, 3), dtype=np.uint8) * 255 # All 255 (White)ve + ones = diff.astype(np.float32) / 255.0 + print("ones shape: ", ones.shape, ones.min(), ones.max()) + + plt.imsave(f"outputs/low-fre-first/{i}_mask.png", ones, cmap='gray') + plt.imsave(f"outputs/low-fre-first/{i}_fft.png", fft, cmap='gray') + + try: + diff = img[:image_size, (i-1) * image_size:(i) * image_size] - \ + img[:image_size, (i) * image_size:(i + 1) * image_size] + + except: + diff = np.zeros_like(img[:image_size, : image_size]) + + # plt.imsave(f"outputs/low-fre-first/{i}_mask_diff.png", diff, cmap='gray') + # print("diff shape: ", diff.shape, diff.min(), diff.max()) + + + diffsig = diff * fft + # save it as a red img, but the bg is trasparent + alpha_channel = np.full_like(diff, 255, dtype=np.uint8) * diff + alpha_channel = np.expand_dims(alpha_channel, axis=-1) + + diff = (diff * 255).astype(np.uint8) + diff = np.stack([diff, np.zeros_like(diff), np.zeros_like(diff)], axis=-1) + # Create an alpha channel (255 for full opacity) + + # Concatenate RGB with Alpha channel + diff = np.concatenate([diff, alpha_channel], axis=-1) + diff = diff.astype(np.uint8) + + # print("diff shape: ", diff.shape, diff.min(), diff.max()) + + + plt.imsave(f"outputs/low-fre-first/{i}_mask_diff_red.png", diff, cmap='gray') + 
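+ # diffsig = diff * fft overlays the columns that change between mask i-1 and mask i on the
+ # normalized log-magnitude spectrum, i.e. it highlights the frequency band added or removed
+ # at this step.
+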
plt.imsave(f"outputs/low-fre-first/{i}_mask_diffsig.png", diffsig, cmap='gray') + + + plt.imsave("outputs/masked_img.png", masked_img, cmap='gray') + plt.figure(figsize=(5*time_step, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + + print("\n\nSecond stage...") + + # ------------------------------- ------------------------------- ------------------------------- + # ------------------------------- ------------------------------- ------------------------------- + + # Second STEP completely Random + import matplotlib.pyplot as plt + import numpy as np + + + final_mask = save_masks[-1][0].numpy() + new_masks = [] + + plt.imshow(final_mask, cmap='gray') + plt.show() + + height, width = final_mask.shape + print("final_mask shape: ", final_mask.shape) + + # Count ones and zeros + ones = np.sum(final_mask[0] == 1) + zeros = np.sum(final_mask[0] == 0) + + print("Initial ones count:", ones) + print("Initial zeros count:", zeros) + + # Identify initially filled and empty strips + initial_filled_indices = np.where(final_mask[0] == 1)[0] + remaining_indices = np.where(final_mask[0] == 0)[0] + + # Shuffle remaining indices to randomize filling order + np.random.shuffle(remaining_indices) + + # Split remaining indices into `time_step` parts + fills_per_step = np.array_split(remaining_indices, time_step) + + masked_img = [] # Store masks at each step + + # Copy initial mask + current_mask = final_mask.copy() + new_masks.append(current_mask.copy()) # Store initial state + + # Fill remaining strips over time + for i in range(time_step): + current_mask[:, fills_per_step[i - 1]] = 1 # Fill new strips + new_masks.append(current_mask.copy()) # Store new mask + # current_mask.append(final_mask) # Store new mask + + new_masks = new_masks[::-1] # Reverse list to get correct order + masked_img = [] + + for m in new_masks: + m = torch.from_numpy(m) #.unsqueeze(0) + + img = apply_ksu_kernel(original_img, m) + masked_img.append(img) + + masks = np.concatenate(new_masks, axis=-1) + masked_img = torch.concat(masked_img, dim=-1).numpy() #+ 1) * 0.5 + + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + + print("masked_img shape: ", masked_img.shape) + + + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + # masked_img = (masked_img - masked_img.min()) / (masked_img.max() - masked_img.min()) + + img = np.concatenate([masks, masked_img], axis=0) + + + min_ = masked_img.min() + max_ = masked_img.max() + + out = img[image_size: 2 * image_size, : image_size] + fft, _ = apply_tofre(torch.from_numpy(out), torch.from_numpy(out)) # complex + fft = np.abs(fft.numpy()) + fft = np.log(fft) + fft = (fft - fft.min()) / (fft.max() - fft.min()) + + + for i in range(time_step+1): + # if i % 3 != 0: + # continue + out = img[image_size : 2*image_size, i * image_size : (i + 1) * image_size] + + # out = (out - out.min()) / (out.max() - out.min()) + out = (out - min_) / (max_ - min_) + plt.imsave(f"outputs/random-sample/{i}_image.png", out, cmap='gray') + + if i != 0: + out = img[:image_size, i * image_size:(i + 1) * image_size] + out = (out - out.min()) / (out.max() - out.min()) + + plt.imsave(f"outputs/random-sample/{i}_mask.png", out, cmap='gray') + + save_fft = fft * out + plt.imsave(f"outputs/random-sample/{i}_fft.png", save_fft, cmap='gray') + + noise = np.random.normal(0, 0.2*np.log((time_step-i)+1), out.shape) * fft + save_fft = fft + noise * (1-out) # Sigma + plt.imsave(f"outputs/random-sample/{i}_fft_reverse.png", save_fft, cmap='gray') + + + else: + ones = np.ones_like(out) * 255 + 
plt.imsave(f"outputs/random-sample/{i}_mask.png", ones, cmap='gray') + plt.imsave(f"outputs/random-sample/{i}_fft.png", fft, cmap='gray') + + + try: + diff = img[:image_size, (i-1) * image_size:(i) * image_size] - \ + img[:image_size, (i) * image_size:(i + 1) * image_size] + + except: + diff = np.zeros_like(img[:image_size, : image_size]) + + + plt.imsave(f"outputs/random-sample/{i}_mask_diff.png", diff, cmap='gray') + # print("diff shape: ", diff.shape, diff.min(), diff.max()) + + # save it as a red img, but the bg is trasparent + alpha_channel = np.full_like(diff, 255, dtype=np.uint8) * diff + alpha_channel = np.expand_dims(alpha_channel, axis=-1) + + diff = (diff * 255).astype(np.uint8) + diff = np.stack([diff, np.zeros_like(diff), np.zeros_like(diff)], axis=-1) + # Create an alpha channel (255 for full opacity) + + # Concatenate RGB with Alpha channel + diff = np.concatenate([diff, alpha_channel], axis=-1) + diff = diff.astype(np.uint8) + + print("diff shape: ", diff.shape, diff.min(), diff.max()) + + plt.imsave(f"outputs/random-sample/{i}_mask_diff_red.png", diff, cmap='gray') + + + plt.imsave("outputs/img.png", img, cmap='gray') + # plt.figure(figsize=(5*time_step, 10)) + + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.tight_layout() + plt.show() + + print("\n\nSecond stage...") diff --git a/MRI_recon/code/Frequency-Diffusion/draw/utils/k_degrade_utils.py b/MRI_recon/code/Frequency-Diffusion/draw/utils/k_degrade_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9991a66043ffd35ae4cac7fd64f8d4780f7e97f6 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/draw/utils/k_degrade_utils.py @@ -0,0 +1,312 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift, fftn, ifftn +import sys, os + +from utils.mask_utils import RandomMaskFunc, EquispacedMaskFunc + + +from torch import nn +import matplotlib.pyplot as plt + +def get_fade_kernel(dims, std): + fade_kernel = tgm.image.get_gaussian_kernel2d(dims, std) + fade_kernel = fade_kernel / torch.max(fade_kernel) + fade_kernel = torch.ones_like(fade_kernel) - fade_kernel + # if device_of_kernel == 'cuda': + # fade_kernel = fade_kernel.cuda() + fade_kernel = fade_kernel[1:, 1:] + return fade_kernel + + + +def get_fade_kernels(fade_routine, num_timesteps, image_size, kernel_std,initial_mask): + kernels = [] + for i in range(num_timesteps): + if fade_routine == 'Incremental': + kernels.append(get_fade_kernel((image_size + 1, image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + elif fade_routine == 'Constant': + kernels.append(get_fade_kernel( + (image_size + 1, image_size + 1), + (kernel_std, kernel_std))) + + elif fade_routine == 'Random_Incremental': + kernels.append(get_fade_kernel((2 * image_size + 1, 2 * image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + return torch.stack(kernels) + + +# --------------------------- +# Kspace kernels +# --------------------------- +# cartesian_regular +def get_mask_func(mask_method, af, cf): + if mask_method == 'cartesian_regular': + return EquispacedMaskFractionFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == 'cartesian_random': + return RandomMaskFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == "random": + return RandomMaskFunc([cf], [af]) + + elif mask_method == "randompatch": + return 
RandomPatchFunc([cf], [af]) + + elif mask_method == "equispaced": + return EquispacedMaskFunc([cf], [af]) + + else: + raise NotImplementedError + + +use_fix_center_ratio = False + +class Noisy_Patch(nn.Module): + def __init__(self): + super(Noisy_Patch, self).__init__() + self.af_list = [] + self.cf_list = [] + self.fe_list = [] + self.pe_list = [] + self.seed = 0 + + def append_list(self, at, cf, fe, pe): + self.af_list.append(at) + self.cf_list.append(cf) + self.fe_list.append(fe) + self.pe_list.append(pe) + + def get_noisy_patches(self, t): + af = self.af_list[t] + cf = self.cf_list[t] + fe = self.fe_list[t] + pe = self.pe_list[t] + + patch_mask = get_mask_func("randompatch", af, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=self.seed) # mask (numpy): (fe, pe) + return mask_ + + def forward(self, mask, ts): + # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + # print("use_patch_kernel forward:", t) + # print("mask = ", mask.shape) + # masks_ = [] + for id, t in enumerate(ts): + mask_ = self.get_noisy_patches(t)[0] + # print("mask_ = ", mask_.shape) + # print("mask[id, t] =", mask[t].shape) + + mask[t] = mask_.to(mask[t].device) * mask[t] + self.seed += ts[0].item() + + # masks_ = torch.stack(masks_).cuda() + # print("masks_ = ", masks_.shape) + # print("mask = ", mask.shape) # B, T, H, W + + return mask + +get_noisy_patches = Noisy_Patch() + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False, sort_center=True): + mask_func = get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random', 'equispaced']: + print("pe:", pe) + mask = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'equispaced': + mask = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (pe, pe) --> (1, pe, pe) + + else: + raise NotImplementedError + + # print("return mask = ", mask.shape) + return mask + + + +def get_ksu_kernel(timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=4, center_fraction=0.08, accelerate_mask=None, sort_center=True): + + if accelerated_factor == 4: + mask_method, center_fraction = "cartesian_random", center_fraction #0.08 # 0.15 + + else: + mask_method, center_fraction = "equispaced", center_fraction # 0.04 + + + center_ratio_factor = center_fraction * accelerated_factor + + masks = [] + noisy_masks = [] + ksu_mask_pe = ksu_mask_fe = image_size # , ksu_mask_pe=320, ksu_mask_fe=320 + # ksu_mask_fe + if ksu_routine == 'LinearSamplingRate': + # Generate the 
sampling rate list with torch.linspace, reversed, and skip the first element + sr_list = torch.linspace(start=1/accelerated_factor, end=1, steps=timesteps + 1).flip(0) + sr_list = [sr.item() for sr in sr_list] + # Start from 0.01 + for sr in sr_list: + # sr = sr.item() + af = 1 / sr # * accelerated_factor # acceleration factor + cf = center_fraction if use_fix_center_ratio else sr_list[0] * center_ratio_factor + + masks.append(get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe)) + + elif ksu_routine == 'LogSamplingRate': + + # Generate the sampling rate list with torch.logspace, reversed, and skip the first element + sr_list = torch.logspace(start=-torch.log10(torch.tensor(accelerated_factor)), + end=0, steps=timesteps + 1).flip(0) + + sr_list = [sr.item() for sr in sr_list] + af = 1 / sr_list[-1] + cf = center_fraction if use_fix_center_ratio else sr_list[-1] * center_ratio_factor + + + if isinstance(accelerate_mask, type(None)): + cache_mask = get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe, sort_center=sort_center) + print("cache_mask = ", cache_mask.shape) # torch.Size([1, 320, 320]) + else: + cache_mask = accelerate_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + masks.append(cache_mask) + + sr_list = sr_list[:-1][::-1] #.flip(0) # Flip? + + for sr in sr_list: + af = 1 / sr + cf = center_fraction if use_fix_center_ratio else sr * center_ratio_factor + # print("af = ", af, cf) + + H, W = cache_mask.shape[1], cache_mask.shape[2] + new_mask = cache_mask.clone() + + # Add additional lines to the mask based on new acceleration factor + total_lines = H + sampled_lines = int(total_lines / af) + existing_lines = new_mask.squeeze(0).sum(dim=0).nonzero(as_tuple=True)[0].tolist() + + remaining_lines = [i for i in range(total_lines) if i not in existing_lines] + + if sampled_lines > len(existing_lines): + center = W // 2 + additional_lines = sampled_lines - len(existing_lines) # sample number + + sorted_indices = sorted(remaining_lines, key=lambda x: abs(x - center)) + + # Take the closest `additional_lines` indices + sampled_indices = sorted_indices[:additional_lines] + + # Remove sampled indices from remaining_lines + for idx in sampled_indices: + remaining_lines.remove(idx) + + # Update new_mask for each sampled index + for idx in sampled_indices: + new_mask[:, :, idx] = 1.0 + + + + cache_mask = new_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + + masks.append(cache_mask) + + # reverse + masks = masks[::-1] + noisy_masks = masks # noisy_masks[::-1] + + + elif mask_method == 'gaussian_2d': + raise NotImplementedError("Gaussian 2D mask type is not implemented.") + + else: + raise NotImplementedError(f'Unknown k-space undersampling routine {ksu_routine}') + + # Return masks, excluding the first one + return masks + + + +class high_fre_mask: + def __init__(self): + self.mask_cache = {} + + def __call__(self, H, W): + if (H, W) in self.mask_cache: + return self.mask_cache[(H, W)] + center_x, center_y = H // 2, W // 2 + radius = H//8 # 影响的频率范围半径 + + high_freq_mask = torch.ones(H, W) + for i in range(H): + for j in range(W): + if (i - center_x) ** 2 + (j - center_y) ** 2 <= radius ** 2: + high_freq_mask[i, j] = 0.0 + self.mask_cache[(H, W)] = high_freq_mask + return high_freq_mask + + +high_fre_mask_cls = high_fre_mask() + + + +def apply_ksu_kernel(x_start, mask): + fft, mask = apply_tofre(x_start, mask) + fft = fft * mask + x_ksu = 
apply_to_spatial(fft) + + return x_ksu + +# from dataloaders.math import ifft2c, fft2c, complex_abs + +def apply_tofre(x_start, mask): + # B, C, H, W = x_start.shape + kspace = fftshift(fft2(x_start, norm=None, dim=(-2, -1)), dim=(-2, -1)) # Default: all dimensions + mask = mask.to(kspace.device) + return kspace, mask + +def apply_to_spatial(fft): + x_ksu = ifft2(ifftshift(fft, dim=(-2, -1)), norm=None, dim=(-2, -1)) # ortho + # After ifftn, the output is already in the spatial domain + x_ksu = x_ksu.real #torch.abs(x_ksu) # + return x_ksu + diff --git a/MRI_recon/code/Frequency-Diffusion/draw/utils/mask_utils.py b/MRI_recon/code/Frequency-Diffusion/draw/utils/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..43b2d74d38a5add14c9815b3a883b53f7e49a0fc --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/draw/utils/mask_utils.py @@ -0,0 +1,195 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). 
+ + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/LICENSE b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/README.md b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8f9aaadef4dd0210e6f11eb09f082c241e08051e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/README.md @@ -0,0 +1,97 @@ +# FSMNet +FSMNet efficiently explores global dependencies across different modalities. Specifically, the features for each modality are extracted by the Frequency-Spatial Feature Extraction (FSFE) module, featuring a frequency branch and a spatial branch. Benefiting from the global property of the Fourier transform, the frequency branch can efficiently capture global dependency with an image-size receptive field, while the spatial branch can extract local features. 
To exploit complementary information from the auxiliary modality, we propose a Cross-Modal Selective fusion (CMS-fusion) module that selectively incorporates the frequency and spatial features from the auxiliary modality to enhance the corresponding branch of the target modality. To further integrate the enhanced global features from the frequency branch and the enhanced local features from the spatial branch, we develop a Frequency-Spatial fusion (FS-fusion) module, resulting in a comprehensive feature representation for the target modality. + +
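To make the frequency/spatial split above concrete, here is a minimal, self-contained sketch of the two kinds of branches. The class names and layer choices are hypothetical illustrations of the general idea (global mixing in k-space versus local convolution in the image domain); this is not the repository's actual FSFE, CMS-fusion, or FS-fusion implementation.

```python
# Illustrative sketch only (hypothetical classes, not the repository code):
# the frequency branch mixes features in k-space after an FFT, so every output
# pixel depends on the whole image, while the spatial branch uses a small
# convolution with a local receptive field.
import torch
import torch.nn as nn

class ToyFrequencyBranch(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        # 1x1 conv over stacked real/imag parts of the spectrum (global mixing)
        self.mix = nn.Conv2d(2 * channels, 2 * channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, C, H, W), real-valued
        spec = torch.fft.rfft2(x, norm="ortho")            # complex spectrum
        feat = torch.cat([spec.real, spec.imag], dim=1)    # (B, 2C, H, W//2+1)
        real, imag = self.mix(feat).chunk(2, dim=1)
        return torch.fft.irfft2(torch.complex(real, imag), s=x.shape[-2:], norm="ortho")

class ToySpatialBranch(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

if __name__ == "__main__":
    x = torch.randn(1, 8, 64, 64)
    # crude stand-in for fusing the two branches
    y = ToyFrequencyBranch(8)(x) + ToySpatialBranch(8)(x)
    print(y.shape)  # torch.Size([1, 8, 64, 64])
```
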

+ +## Paper + +Accelerated Multi-Contrast MRI Reconstruction via Frequency and Spatial Mutual Learning
+[Qi Chen](https://scholar.google.com/citations?user=4Q5gs2MAAAAJ&hl=en)<sup>1</sup>, [Xiaohan Xing](https://hathawayxxh.github.io/)<sup>2,*</sup>, [Zhen Chen](https://franciszchen.github.io/)<sup>3</sup>, [Zhiwei Xiong](http://staff.ustc.edu.cn/~zwxiong/)<sup>1</sup>
+<sup>1</sup> University of Science and Technology of China,
+<sup>2</sup> Stanford University,
+<sup>3</sup> Centre for Artificial Intelligence and Robotics (CAIR), HKISI-CAS
+MICCAI, 2024
+[paper](http://arxiv.org/abs/2409.14113) | [code](https://github.com/qic999/FSMNet) | [huggingface](https://huggingface.co/datasets/qicq1c/MRI_Reconstruction) + +## 0. Installation + +```bash +git clone https://github.com/qic999/FSMNet.git +cd FSMNet +``` + +See [installation instructions](documents/INSTALL.md) to create an environment and obtain requirements. + +## 1. Prepare datasets +Download BraTS dataset and fastMRI dataset and save them to the `datapath` directory. +``` +cd $datapath +# download brats dataset +wget https://huggingface.co/datasets/qicq1c/MRI_Reconstruction/resolve/main/BRATS_100patients.zip +unzip BRATS_100patients.zip +# download fastmri dataset +wget https://huggingface.co/datasets/qicq1c/MRI_Reconstruction/resolve/main/singlecoil_train_selected.zip +unzip singlecoil_train_selected.zip +``` + +## 2. Training +##### BraTS dataset, AF=4 +``` +python train_brats.py --root_path /data/qic99/MRI_recon image_100patients_4X/ \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x +``` + +##### BraTS dataset, AF=8 +``` +python train_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x +``` + +##### fastMRI dataset, AF=4 +``` +python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x +``` + +##### fastMRI dataset, AF=8 +``` +python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x +``` + +## 3. 
Testing +##### BraTS dataset, AF=4 +``` +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_4X/ \ + --gpu 3 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x --phase test +``` + +##### BraTS dataset, AF=8 +``` +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \ + --gpu 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x --phase test +``` + +##### fastMRI dataset, AF=4 +``` +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 5 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --phase test +``` + +##### fastMRI dataset, AF=8 +``` +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 6 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --phase test +``` \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/bash/brats.sh b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/bash/brats.sh new file mode 100644 index 0000000000000000000000000000000000000000..f0925f21f2c63b2dac30a07f7bbcd9f07e20abdc --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/bash/brats.sh @@ -0,0 +1,37 @@ +# BraTS dataset, AF=4 +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion/experiments/FSMNet + +#root_path=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_4X/ +root_path=/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/ + +python train_brats.py --root_path $root_path\ + --gpu 0 --batch_size 4 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x + +BraTS dataset, AF=8 + +python train_brats.py --root_path /gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/ \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x + + + + +# Test +BraTS dataset, AF=4 + +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_4X/ \ + --gpu 3 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x --phase test + +BraTS dataset, AF=8 + +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \ + --gpu 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x --phase test + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/bash/fastmri.sh b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/bash/fastmri.sh new file mode 100644 index 0000000000000000000000000000000000000000..01c570029ae19591c104c360e0b5f2ae87292532 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/bash/fastmri.sh @@ -0,0 +1,32 @@ +# fastMRI dataset, AF=4 +# BraTS dataset, AF=4 +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion/experiments/FSMNet + +data_root=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee +#data_root=/gamedrive/Datasets/medical/FrequencyDiffusion + + +python train_fastmri.py --root_path $data_root \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x + +# fastMRI dataset, AF=8 + +python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 
--ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x + + +# Test +# fastMRI dataset, AF=4 + +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 5 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --phase test + +# fastMRI dataset, AF=8 + +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 6 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --phase test diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_DuDo_dataloader.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_DuDo_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b06691ee683a347d4a20948d03598db65e9c08 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_DuDo_dataloader.py @@ -0,0 +1,295 @@ +""" +Dataloader for the dual-domain network: it loads the under-sampled and fully-sampled kspace data of both modalities, together with the high-quality images used as the supervision signal. +""" + + +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset + +from .kspace_subsample import undersample_mri, mri_fft + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm. + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift.
+ + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + + return data + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, HF_refine = 'False', split='train', MRIDOWN='4X', SNR=15, \ + transform=None, input_round = None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self.HF_refine = HF_refine + self.input_round = input_round + self._MRIDOWN = MRIDOWN + self._SNR = SNR + self.im_ids = [] + self.t2_images = [] + self.splits_path = "/data/xiaohan/BRATS_dataset/cv_splits_100patients/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + # self.test_file = self.splits_path + 'train_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + self.t2_images.append(t2_path) + + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + image_name = self.t1_images[index].split('t1')[0] + # print("image name:", image_name) + + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + # print("images:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + # print("t1 before standardization:", t1.max(), t1.min(), t1.mean()) + # print("loaded t1 range:", t1.max(), t1.min()) + # print("loaded t2 range:", t2.max(), t2 .min()) + + ### normalize the MRI image by divide_max + t1_max, t2_max = t1.max(), t2.max() + t1 = t1/t1_max + t2 = t2/t2_max + sample_stats = {"t1_max": t1_max, "t2_max": t2_max, "image_name": image_name} + + # sample_stats = {"t1_max": 1.0, "t2_max": 1.0} + + ### convert images to kspace and perform undersampling. 
+ t1_kspace_in, t1_in, t1_kspace, t1_img = mri_fft(t1, _SNR = self._SNR) + t2_kspace_in, t2_in, t2_kspace, t2_img, mask = undersample_mri( + t2, _MRIDOWN = self._MRIDOWN, _SNR = self._SNR) + + + # print("loaded t2 range:", t2.max(), t2.min()) + # print("t2_under_img range:", t2_under_img.max(), t2_under_img.min()) + # print("t2_kspace real_part range:", t2_kspace.real.max(), t2_kspace.real.min()) + # print("t2_kspace imaginary_part range:", t2_kspace.imag.max(), t2_kspace.imag.min()) + # print("t2_kspace_in real_part range:", t2_kspace_in.real.max(), t2_kspace_in.real.min()) + # print("t2_kspace_in imaginary_part range:", t2_kspace_in.imag.max(), t2_kspace_in.imag.min()) + + if self.HF_refine == "False": + sample = {'t1': t1_img, 't1_in': t1_in, 't1_kspace': t1_kspace, 't1_kspace_in': t1_kspace_in, \ + 't2': t2_img, 't2_in': t2_in, 't2_kspace': t2_kspace, 't2_kspace_in': t2_kspace_in, \ + 't2_mask': mask} + + elif self.HF_refine == "True": + ### 读取上一步重建的kspace data. + t1_krecon_path = self._base_dir + self.t1_images[index].replace( + 't1.png', 't1_' + str(self._SNR) + 'dB_recon_kspace_' + self.input_round + '_DudoLoss.npy') + t2_krecon_path = self._base_dir + self.t1_images[index].replace('t1.png', 't2_' + self._MRIDOWN + \ + '_' + str(self._SNR) + 'dB_recon_kspace_' + self.input_round + '_DudoLoss.npy') + + t1_krecon = np.load(t1_krecon_path) + t2_krecon = np.load(t2_krecon_path) + # print("t1 and t2 recon kspace:", t1_krecon.shape, t2_krecon.shape) + # + sample = {'t1': t1_img, 't1_in': t1_in, 't1_kspace': t1_kspace, 't1_kspace_in': t1_kspace_in, \ + 't2': t2_img, 't2_in': t2_in, 't2_kspace': t2_kspace, 't2_kspace_in': t2_kspace_in, \ + 't2_mask': mask, 't1_krecon': t1_krecon, 't2_krecon': t2_krecon} + + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img = torch.from_numpy(img).float() + target = torch.from_numpy(target).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'ct': img, 'mri': target} + + +# class ToTensor(object): +# """Convert ndarrays in sample to Tensors.""" + +# def __call__(self, sample): +# # swap color axis because +# # numpy image: H x W x C +# # torch image: C X H X W +# img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) +# img = sample['image'][:, :, None].transpose((2, 0, 1)) +# target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) +# target = sample['target'][:, :, None].transpose((2, 0, 1)) +# # print("img_in before_numpy range:", img_in.max(), img_in.min()) +# img_in = torch.from_numpy(img_in).float() +# img = torch.from_numpy(img).float() +# target_in = torch.from_numpy(target_in).float() +# target = torch.from_numpy(target).float() +# # print("img_in range:", img_in.max(), img_in.min()) + +# return {'ct_in': img_in, +# 'ct': img, +# 'mri_in': target_in, +# 'mri': target} diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_dataloader.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..52523fb1e080166812a64c191f92884cee244219 --- /dev/null +++ 
b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_dataloader.py @@ -0,0 +1,175 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import os +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', SNR=15, transform=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.splits_path = "/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/cv_splits/" + + + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + else: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + if MRIDOWN == "False": + t2_under_path = image_path.replace('t1', 't2_' + str(SNR) + 'dB') + else: + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + + # print("image paths:", image_path, t1_under_path, t2_path, t2_under_path) + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + # print("t1 images:", self.t1_images) + # print("t2 images:", self.t2_images) + # print("t1_undermri_images:", self.t1_undermri_images) + # print("t2_undermri_images:", self.t2_undermri_images) + + self.transform = transform + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + ### Two settings: + ### 1. T1 fully-sampled without added noise, T2 down-sampled: MRI acceleration only. + ### 2. T1 fully-sampled but with added noise, T2 down-sampled and also noisy: MRI acceleration and enhancement at the same time. + ### In both settings the T1 and T2 inputs are low-quality images.
+ sample = {'image_in': np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0, + 'image': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + 'target_in': np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0, + 'target': np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0} + + + # ### 2023/05/23, Xiaohan, 把T1模态的输入改成high-quality图像(和ground truth一致,看能否为T2提供更好的guidance)。 + # sample = {'image_in': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + # 'image': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + # 'target_in': np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0, + # 'target': np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0} + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git 
a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_dataloader_new.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_dataloader_new.py new file mode 100644 index 0000000000000000000000000000000000000000..288e448bd06ffd5fd94e253e742dd29ed253ab34 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_dataloader_new.py @@ -0,0 +1,371 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset +from torchvision import transforms + + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', \ + SNR=15, transform=None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.t1_krecon_images, self.t2_krecon_images = [], [] + self.kspace_refine = "False" # ADD + + + name = base_dir.rstrip("/ ").split('/')[-1] + print("base_dir=", base_dir, ", folder name =", name) + self.splits_path = base_dir.replace(name, 'cv_splits_100patients/') + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + + if SNR == 0: + t1_under_path = image_path + + if self.kspace_refine == "False": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + elif self.kspace_refine == "True": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_krecon') + + if self.kspace_refine == "False": + t1_krecon_path = image_path + t2_krecon_path = image_path + + # if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + + else: + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + 
t1_krecon_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_krecon_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + self.t1_krecon_images.append(t1_krecon_path) + self.t2_krecon_images.append(t2_krecon_path) + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t1_krecon = np.array(Image.open(self._base_dir + self.t1_krecon_images[index]))/255.0 + + t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + t2_krecon = np.array(Image.open(self._base_dir + self.t2_krecon_images[index]))/255.0 + + if self.input_normalize == "mean_std": + t1_in, t1_mean, t1_std = normalize_instance(t1_in, eps=1e-11) + t1 = normalize(t1, t1_mean, t1_std, eps=1e-11) + t2_in, t2_mean, t2_std = normalize_instance(t2_in, eps=1e-11) + t2 = normalize(t2, t2_mean, t2_std, eps=1e-11) + + t1_krecon = normalize(t1_krecon, t1_mean, t1_std, eps=1e-11) + t2_krecon = normalize(t2_krecon, t2_mean, t2_std, eps=1e-11) + + ### clamp input to ensure training stability. + t1_in = np.clip(t1_in, -6, 6) + t1 = np.clip(t1, -6, 6) + t2_in = np.clip(t2_in, -6, 6) + t2 = np.clip(t2, -6, 6) + + t1_krecon = np.clip(t1_krecon, -6, 6) + t2_krecon = np.clip(t2_krecon, -6, 6) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + t1_in = (t1_in - t1_in.min())/(t1_in.max() - t1_in.min()) + t1 = (t1 - t1.min())/(t1.max() - t1.min()) + t2_in = (t2_in - t2_in.min())/(t2_in.max() - t2_in.min()) + t2 = (t2 - t2.min())/(t2.max() - t2.min()) + sample_stats = 0 + + elif self.input_normalize == "divide": + sample_stats = 0 + + + sample = {'image_in': t1_in, + 'image': t1, + 'image_krecon': t1_krecon, + 'target_in': t2_in, + 'target': t2, + 'target_krecon': t2_krecon} + + # print("images shape:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + + +def add_gaussian_noise(img, mean=0, std=1): + noise = std * torch.randn_like(img) + mean + noisy_img = img + noise + return torch.clamp(noisy_img, 0, 1) + + + +class AddNoise(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + add_gauss_noise = transforms.GaussianBlur(kernel_size=5) + add_poiss_noise = transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)) + + add_noise = transforms.RandomApply([add_gauss_noise, add_poiss_noise], p=0.5) + + img_in = add_noise(img_in) + target_in = add_noise(target_in) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + + return sample + + + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + 
crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_krecon = sample['image_krecon'] + target_krecon = sample['target_krecon'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + img_krecon = np.pad(img_krecon, pad_size, mode='reflect') + target_krecon = np.pad(target_krecon, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + img_krecon = img_krecon[ww:ww+crop_size, hh:hh+crop_size] + target_krecon = target_krecon[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'image_krecon': img_krecon, \ + 'target_in': target_in, 'target': target, 'target_krecon': target_krecon} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + +class RandomFlip(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + # horizontal flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 1) + img = cv2.flip(img, 1) + target_in = cv2.flip(target_in, 1) + target = cv2.flip(target, 1) + + # vertical flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 0) + img = cv2.flip(img, 0) + target_in = cv2.flip(target_in, 0) + target = cv2.flip(target, 0) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + + +class RandomRotate(object): + def __call__(self, sample, center=None, scale=1.0): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + degrees = [0, 90, 180, 270] + angle = random.choice(degrees) + + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + + img_in = cv2.warpAffine(img_in, matrix, (w, h)) + img = cv2.warpAffine(img, matrix, (w, h)) + target_in = cv2.warpAffine(target_in, matrix, (w, h)) + target = cv2.warpAffine(target, matrix, (w, h)) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + 
+class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + + image_krecon = sample['image_krecon'][:, :, None].transpose((2, 0, 1)) + target_krecon = sample['target_krecon'][:, :, None].transpose((2, 0, 1)) + + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + image_krecon = torch.from_numpy(image_krecon).float() + target_krecon = torch.from_numpy(target_krecon).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'image_in': img_in, + 'image': img, + 'target_in': target_in, + 'target': target, + 'image_krecon': image_krecon, + 'target_krecon': target_krecon} diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_kspace_dataloader.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_kspace_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..871a153b20eac89e45ec0025e2aa31476360fde0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/BRATS_kspace_dataloader.py @@ -0,0 +1,298 @@ +""" +Load the low-quality and high-quality images from the BRATS dataset and transform to kspace. +""" + + +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset + +from .kspace_subsample import undersample_mri, mri_fft + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. 
+ """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + + return data + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', SNR=15, transform=None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.splits_path = "/data/xiaohan/BRATS_dataset/cv_splits_100patients/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + # self.test_file = self.splits_path + 'train_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + else: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + # print("t1 images:", self.t1_images) + # print("t2 images:", self.t2_images) + # 
print("t1_undermri_images:", self.t1_undermri_images) + # print("t2_undermri_images:", self.t2_undermri_images) + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + # t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + # t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + # print("images:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + # print("t1 before standardization:", t1.max(), t1.min(), t1.mean()) + # print("t1 range:", t1.max(), t1.min()) + # print("t2 range:", t2.max(), t2 .min()) + + if self.input_normalize == "mean_std": + ### 对input image和target image都做(x-mean)/std的归一化操作 + t1, t1_mean, t1_std = normalize_instance(t1, eps=1e-11) + t2, t2_mean, t2_std = normalize_instance(t2, eps=1e-11) + + ### clamp input to ensure training stability. + t1 = np.clip(t1, -6, 6) + t2 = np.clip(t2, -6, 6) + # print("t1 after standardization:", t1.max(), t1.min(), t1.mean()) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + # t1 = (t1 - t1.min())/(t1.max() - t1.min()) + # t2 = (t2 - t2.min())/(t2.max() - t2.min()) + t1 = t1/t1.max() + t2 = t2/t2.max() + sample_stats = 0 + + elif self.input_normalize == "divide": + sample_stats = 0 + + + ### convert images to kspace and perform undersampling. 
+ # t1_kspace, t1_masked_kspace, t1_img, t1_under_img = undersample_mri(t1, _MRIDOWN = None) + t1_kspace, t1_img = mri_fft(t1) + t2_kspace, t2_masked_kspace, t2_img, t2_under_img, mask = undersample_mri(t2, _MRIDOWN = self._MRIDOWN) + + + sample = {'t1': t1_img, 't2': t2_img, 'under_t2': t2_under_img, "t2_mask": mask, \ + 't1_kspace': t1_kspace, 't2_kspace': t2_kspace, 't2_masked_kspace': t2_masked_kspace} + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img = torch.from_numpy(img).float() + target = torch.from_numpy(target).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'ct': img, 'mri': target} + + +# class ToTensor(object): +# """Convert ndarrays in sample to Tensors.""" + +# def __call__(self, sample): +# # swap color axis because +# # numpy image: H x W x C +# # torch image: C X H X W +# img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) +# img = sample['image'][:, :, None].transpose((2, 0, 1)) +# target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) +# target = sample['target'][:, :, None].transpose((2, 0, 1)) +# # print("img_in before_numpy range:", img_in.max(), img_in.min()) +# img_in = torch.from_numpy(img_in).float() +# img = torch.from_numpy(img).float() +# target_in = torch.from_numpy(target_in).float() +# target = torch.from_numpy(target).float() +# # print("img_in range:", img_in.max(), img_in.min()) + +# return {'ct_in': img_in, +# 'ct': img, +# 'mri_in': target_in, +# 'mri': target} diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/__init__.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/fastmri.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..0502ac9ccb96df4f55908cd92d5db432239659fe --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/fastmri.py @@ -0,0 +1,222 @@ +import csv +import os +import random +import xml.etree.ElementTree as etree +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import pathlib + +import h5py +import numpy as np +import torch +import yaml +from torch.utils.data import Dataset +from .transforms import build_transforms +from matplotlib import pyplot as plt + + +def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")): + """ + Data directory fetcher. + + This is a brute-force simple way to configure data directories for a + project. Simply overwrite the variables for `knee_path` and `brain_path` + and this function will retrieve the requested subsplit of the data for use. + + Args: + key (str): key to retrieve path from data_config_file. + data_config_file (pathlib.Path, + default=pathlib.Path("fastmri_dirs.yaml")): Default path config + file. + + Returns: + pathlib.Path: The path to the specified directory. 
+ """ + if not data_config_file.is_file(): + default_config = dict( + knee_path="/home/jc3/Data/", + brain_path="/home/jc3/Data/", + ) + with open(data_config_file, "w") as f: + yaml.dump(default_config, f) + + raise ValueError(f"Please populate {data_config_file} with directory paths.") + + with open(data_config_file, "r") as f: + data_dir = yaml.safe_load(f)[key] + + data_dir = pathlib.Path(data_dir) + + if not data_dir.exists(): + raise ValueError(f"Path {data_dir} from {data_config_file} does not exist.") + + return data_dir + + +def et_query( + root: etree.Element, + qlist: Sequence[str], + namespace: str = "http://www.ismrm.org/ISMRMRD", +) -> str: + """ + ElementTree query function. + This can be used to query an xml document via ElementTree. It uses qlist + for nested queries. + Args: + root: Root of the xml to search through. + qlist: A list of strings for nested searches, e.g. ["Encoding", + "matrixSize"] + namespace: Optional; xml namespace to prepend query. + Returns: + The retrieved data as a string. + """ + s = "." + prefix = "ismrmrd_namespace" + + ns = {prefix: namespace} + + for el in qlist: + s = s + f"//{prefix}:{el}" + + value = root.find(s, ns) + if value is None: + raise RuntimeError("Element not found") + + return str(value.text) + + +class SliceDataset(Dataset): + def __init__( + self, + root, + transform, + challenge, + sample_rate=1, + mode='train' + ): + self.mode = mode + + # challenge + if challenge not in ("singlecoil", "multicoil"): + raise ValueError('challenge should be either "singlecoil" or "multicoil"') + self.recons_key = ( + "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss" + ) + # transform + self.transform = transform + + self.examples = [] + + self.cur_path = root + if not os.path.exists(self.cur_path): + self.cur_path = self.cur_path + "_selected" + + self.csv_file = "knee_data_split/singlecoil_" + self.mode + "_split_less.csv" + + with open(self.csv_file, 'r') as f: + reader = csv.reader(f) + + id = 0 + + for row in reader: + pd_metadata, pd_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[0] + '.h5')) + + pdfs_metadata, pdfs_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[1] + '.h5')) + + for slice_id in range(min(pd_num_slices, pdfs_num_slices)): + self.examples.append( + (os.path.join(self.cur_path, row[0] + '.h5'), os.path.join(self.cur_path, row[1] + '.h5') + , slice_id, pd_metadata, pdfs_metadata, id)) + id += 1 + + if sample_rate < 1: + random.shuffle(self.examples) + num_examples = round(len(self.examples) * sample_rate) + + self.examples = self.examples[0:num_examples] + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + + # read pd + pd_fname, pdfs_fname, slice, pd_metadata, pdfs_metadata, id = self.examples[i] + + with h5py.File(pd_fname, "r") as hf: + pd_kspace = hf["kspace"][slice] + + pd_mask = np.asarray(hf["mask"]) if "mask" in hf else None + + pd_target = hf[self.recons_key][slice] if self.recons_key in hf else None + + attrs = dict(hf.attrs) + + attrs.update(pd_metadata) + + if self.transform is None: + pd_sample = (pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + else: + pd_sample = self.transform(pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + + with h5py.File(pdfs_fname, "r") as hf: + pdfs_kspace = hf["kspace"][slice] + pdfs_mask = np.asarray(hf["mask"]) if "mask" in hf else None + + pdfs_target = hf[self.recons_key][slice] if self.recons_key in hf else None + + attrs = dict(hf.attrs) + + 
attrs.update(pdfs_metadata) + + if self.transform is None: + pdfs_sample = (pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + else: + pdfs_sample = self.transform(pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + + + # dataset pdf mean and std tensor(3.1980e-05) tensor(1.3093e-05) + # print("dataset pdf mean and std", pdfs_sample[2], pdfs_sample[3]) + # print(pdfs_sample[1].shape, pdfs_sample[1].min(), pdfs_sample[1].max()) + + return (pd_sample, pdfs_sample, id) + + def _retrieve_metadata(self, fname): + with h5py.File(fname, "r") as hf: + et_root = etree.fromstring(hf["ismrmrd_header"][()]) + + enc = ["encoding", "encodedSpace", "matrixSize"] + enc_size = ( + int(et_query(et_root, enc + ["x"])), + int(et_query(et_root, enc + ["y"])), + int(et_query(et_root, enc + ["z"])), + ) + rec = ["encoding", "reconSpace", "matrixSize"] + recon_size = ( + int(et_query(et_root, rec + ["x"])), + int(et_query(et_root, rec + ["y"])), + int(et_query(et_root, rec + ["z"])), + ) + + lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"] + enc_limits_center = int(et_query(et_root, lims + ["center"])) + enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1 + + padding_left = enc_size[1] // 2 - enc_limits_center + padding_right = padding_left + enc_limits_max + + num_slices = hf["kspace"].shape[0] + + metadata = { + "padding_left": padding_left, + "padding_right": padding_right, + "encoding_size": enc_size, + "recon_size": recon_size, + } + + return metadata, num_slices + + +def build_dataset(args, mode='train', sample_rate=1): + assert mode in ['train', 'val', 'test'], 'unknown mode' + transforms = build_transforms(args, mode) + return SliceDataset(os.path.join(args.root_path, 'singlecoil_' + mode), transforms, 'singlecoil', sample_rate=sample_rate, mode=mode) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/hybrid_sparse.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/hybrid_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..42e4a7e33c2204c13a1c4509897baf19e1fb07f1 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/hybrid_sparse.py @@ -0,0 +1,156 @@ +from __future__ import print_function, division +import numpy as np +from glob import glob +import random +from skimage import transform + +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', transform=None): + + super().__init__() + self._base_dir = base_dir + self.im_ids = [] + self.images = [] + self.gts = [] + + if split=='train': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir+"/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + + elif split=='test': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir + "/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + self.transform = transform + + assert (len(self.images) == len(self.gts)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.images))) + + def __len__(self): + return len(self.images) + + + def __getitem__(self, index): + img_in, img, target_in, target= self._make_img_gt_point_pair(index) + sample = {'image_in': img_in, 'image':img, 'target_in': 
target_in, 'target': target} + # print("image in:", img_in.shape) + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + + # the default setting (i.e., rawdata.npz) is 4X64P + dd = np.load(self.images[index].replace('.png', '_raw_4X64P.npz')) + # print("images range:", dd['fbp'].max(), dd['ct'].max(), dd['under_t1'].max(), dd['t1'].max()) + _img_in = dd['fbp'] + _img_in[_img_in>0.6]=0.6 + _img_in = _img_in/0.6 + + _img = dd['ct'] + _img =(_img/1000*0.192+0.192) + _img[_img<0.0]=0.0 + _img[_img>0.6]=0.6 + _img = _img/0.6 + + _target_in = dd['under_t1'] + _target = dd['t1'] + + return _img_in, _img, _target_in, _target + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 400, 400 + crop_size = 384 + pad_size = (400-384)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/kspace_subsample.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/kspace_subsample.py new file mode 100644 index 
0000000000000000000000000000000000000000..fb5b5694d8fee8b35ba8394fae98fe2d3aa25759 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/kspace_subsample.py @@ -0,0 +1,287 @@ +""" +2023/10/16, +preprocess kspace data with the undersampling mask in the fastMRI project. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2), norm='ortho') + # Shift the zero-frequency component to the center of the k-space spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + masked_spectrum = spectrum * mask[None, :, :, None] + return spectrum, masked_spectrum + + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + spectrum: input k-space tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2), norm='ortho') + + return image + + +def add_gaussian_noise(kspace, snr): + ### scale the noise power according to the target SNR + num_pixels = kspace.shape[0]*kspace.shape[1]*kspace.shape[2]*kspace.shape[3] + psr = torch.sum(torch.abs(kspace.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(kspace.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(kspace.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(kspace.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noisy_kspace = kspace + noise + + return noisy_kspace + + +def mri_fft(raw_mri, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # Shift the zero-frequency component to the center of the k-space spectrum + kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + + if _SNR > 0: + noisy_kspace = add_gaussian_noise(kspace, _SNR) + else: + noisy_kspace = kspace + + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1) + + + +def undersample_mri(raw_mri, _MRIDOWN, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] # [1, 240] + # print("mask:", mask.shape) + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + ### under-sample the kspace data.
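+    ### Pipeline: FFT the image to centered k-space, zero out the columns the mask
+    ### does not sample, optionally add complex Gaussian noise at the requested SNR,
+    ### then inverse-FFT and take the magnitude to get the zero-filled reconstruction.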
+ kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) + ### add low-field noise to the kspace data. + if _SNR > 0: + noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) + else: + noisy_kspace = masked_kspace + + ### conver the corrupted kspace data back to noisy MRI image. + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/math.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/math.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0d76246b0c8fb757b6b589b7f889e0316696d0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/math.py @@ -0,0 +1,272 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. + + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + +# def fft2c(data): +# """ +# Apply centered 2 dimensional Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The FFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# data = torch.fft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + +# def ifft2c(data): +# """ +# Apply centered 2-dimensional Inverse Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. 
All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The IFFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# # data = torch.ifft(data, 2, normalized=True) +# data = torch.fft.ifft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. 
+ """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/subsample.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..d0620da3414c6077e4293376fb8a9be01ad19990 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/subsample.py @@ -0,0 +1,195 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. 
If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. 
The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/transforms.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..536eecc5bef52a969001f5f68fc91a38fdc549ba --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/dataloaders/transforms.py @@ -0,0 +1,485 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch + +from .math import ifft2c, fft2c, complex_abs +from .subsample import create_mask_for_mask_type, MaskFunc +import random + +from typing import Dict, Optional, Sequence, Tuple, Union +from matplotlib import pyplot as plt +import os + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data. + """ + data = data.numpy() + + return data[..., 0] + 1j * data[..., 1] + + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. 
+ + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + +def mask_center(x, mask_from, mask_to): + mask = torch.zeros_like(x) + mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to] + + return mask + + +def center_crop(data, shape): + """ + Apply a center crop to the input real image or batch of real images. + + Args: + data (torch.Tensor): The input tensor to be center cropped. It should + have at least 2 dimensions and the cropping is applied along the + last two dimensions. + shape (int, int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image. + """ + assert 0 < shape[0] <= data.shape[-2] + assert 0 < shape[1] <= data.shape[-1] + + w_from = (data.shape[-2] - shape[0]) // 2 + h_from = (data.shape[-1] - shape[1]) // 2 + w_to = w_from + shape[0] + h_to = h_from + shape[1] + + return data[..., w_from:w_to, h_from:h_to] + + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + +def center_crop_to_smallest(x, y): + """ + Apply a center crop on the larger image to the size of the smaller. + + The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at + dim=-1 and y is smaller than x at dim=-2, then the returned dimension will + be a mixture of the two. + + Args: + x (torch.Tensor): The first image. + y (torch.Tensor): The second image + + Returns: + tuple: tuple of tensors x and y, each cropped to the minimim size. + """ + smallest_width = min(x.shape[-1], y.shape[-1]) + smallest_height = min(x.shape[-2], y.shape[-2]) + x = center_crop(x, (smallest_height, smallest_width)) + y = center_crop(y, (smallest_height, smallest_width)) + + return x, y + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. 
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +class DataTransform(object): + """ + Data Transformer for training U-Net models. + """ + + def __init__(self, which_challenge): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.which_challenge = which_challenge + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. 
+ """ + kspace = to_tensor(kspace) + + # inverse Fourier transform to get zero filled solution + image = ifft2c(kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + + # getLR + imgfft = fft2c(image) + imgfft = complex_center_crop(imgfft, (160, 160)) + LR_image = ifft2c(imgfft) + + # absolute value + LR_image = complex_abs(LR_image) + + # normalize input + LR_image, mean, std = normalize_instance(LR_image, eps=1e-11) + LR_image = LR_image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return LR_image, target, mean, std, fname, slice_num + +class DenoiseDataTransform(object): + def __init__(self, size, noise_rate): + super(DenoiseDataTransform, self).__init__() + self.size = (size, size) + self.noise_rate = noise_rate + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + max_value = attrs["max"] + + #target + target = to_tensor(target) + target = center_crop(target, self.size) + target, mean, std = normalize_instance(target, eps=1e-11) + target = target.clamp(-6, 6) + + #image + kspace = to_tensor(kspace) + complex_image = ifft2c(kspace) #complex_image + image = complex_center_crop(complex_image, self.size) + noise_image = self.rician_noise(image, max_value) + noise_image = complex_abs(noise_image) + + noise_image = normalize(noise_image, mean, std, eps=1e-11) + noise_image = noise_image.clamp(-6, 6) + + return noise_image, target, mean, std, fname, slice_num + + + def rician_noise(self, X, noise_std): + #Add rician noise with variance sampled uniformly from the range 0 and 0.1 + noise_std = random.uniform(0, noise_std*self.noise_rate) + Ir = X + noise_std * torch.randn(X.shape) + Ii = noise_std*torch.randn(X.shape) + In = torch.sqrt(Ir ** 2 + Ii ** 2) + return In + + +def apply_mask( + data: torch.Tensor, + mask_func: MaskFunc, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Optional[Sequence[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Subsample given k-space by multiplying with a mask. + Args: + data: The input k-space data. This should have at least 3 dimensions, + where dimensions -3 and -2 are the spatial dimensions, and the + final dimension has size 2 (for complex values). + mask_func: A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed: Seed for the random number generator. + padding: Padding value to apply for mask. + Returns: + tuple containing: + masked data: Subsampled k-space data + mask: The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + + +class ReconstructionTransform(object): + """ + Data Transformer for training U-Net models. 
+ """ + + def __init__(self, which_challenge, mask_func=None, use_seed=True): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.mask_func = mask_func + self.which_challenge = which_challenge + self.use_seed = use_seed + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. + """ + kspace = to_tensor(kspace) + + # apply mask + if self.mask_func: + seed = None if not self.use_seed else tuple(map(ord, fname)) + masked_kspace, mask = apply_mask(kspace, self.mask_func, seed) + else: + masked_kspace = kspace + + # inverse Fourier transform to get zero filled solution + image = ifft2c(masked_kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + # print('image',image.shape) + # absolute value + image = complex_abs(image) + + # apply Root-Sum-of-Squares if multicoil data + if self.which_challenge == "multicoil": + image = rss(image) + + # normalize input + image, mean, std = normalize_instance(image, eps=1e-11) + image = image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return image, target, mean, std, fname, slice_num + + +def build_transforms(args, mode = 'train'): + + challenge = 'singlecoil' + if mode == 'train': + mask = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + return ReconstructionTransform(challenge, mask, use_seed=False) + elif mode == 'val': + mask = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + return ReconstructionTransform(challenge, mask) + else: + return ReconstructionTransform(challenge) + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/documents/INSTALL.md b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/documents/INSTALL.md new file mode 100644 index 
0000000000000000000000000000000000000000..9912721cb3354240d99c08838ae8d2b1417b339b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/documents/INSTALL.md @@ -0,0 +1,11 @@ +## Dependency +The code is tested on `python 3.8, Pytorch 1.13`. + +##### Setup environment + +```bash +conda create -n FSMNet python=3.8 +source activate FSMNet # or conda activate FSMNet +pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install einops h5py matplotlib scikit_image tensorboardX yacs pandas opencv-python timm ml_collections +``` diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/figures/FSMNet.png b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/figures/FSMNet.png new file mode 100644 index 0000000000000000000000000000000000000000..127848f2c580c8d91d9cff8890500e5f3c830d72 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/figures/FSMNet.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40bb9cbda0a8f926ea4ef8d92228ce591766b1d3176000db5758b2edf1a6249b +size 378629 diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/knee_data_split/singlecoil_train_split_less.csv b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/knee_data_split/singlecoil_train_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..d85707318750900b14a6e7100541242a60b7a310 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/knee_data_split/singlecoil_train_split_less.csv @@ -0,0 +1,227 @@ +file1000685,file1000568,0.301723929779229 +file1002273,file1000481,0.302226224199571 +file1000472,file1000142,0.304272730770318 +file1002186,file1000863,0.304812175768496 +file1002385,file1002518,0.305357274240413 +file1000981,file1000129,0.305533361411383 +file1001320,file1001948,0.306821514316368 +file1000633,file1002243,0.306892354331709 +file1001872,file1001294,0.308345907393103 +file1001474,file1001830,0.310481695157561 +file1001005,file1001283,0.310497722435023 +file1001690,file1001519,0.310709448786299 +file1002469,file1001811,0.31193137253455 +file1000914,file1000242,0.31237190359308 +file1002284,file1002012,0.315366393843169 +file1001721,file1001328,0.31735122361847 +file1000807,file1002334,0.320096908959039 +file1001944,file1002335,0.320272061156991 +file1002090,file1002431,0.320351887633851 +file1000499,file1002063,0.320786426659383 +file1001362,file1000509,0.32175341740359 +file1001421,file1000597,0.324291432700032 +file1000349,file1000321,0.324545110048573 +file1002123,file1001235,0.327142348994532 +file1001867,file1002086,0.328624781732941 +file1001007,file1001027,0.330759860300298 +file1001915,file1000088,0.331499371283099 +file1001661,file1000313,0.331905252950291 +file1000383,file1000307,0.339998107225229 +file1000116,file1000632,0.34069458535013 +file1002303,file1000173,0.343821267871409 +file1000306,file1001277,0.344751178043605 +file1000003,file1001922,0.346138116633394 +file1000109,file1000143,0.347632265547478 +file1001999,file1000115,0.348248659775587 +file1000089,file1000326,0.348964657514049 +file1001205,file1002232,0.349375610862454 +file1000557,file1000619,0.351305005151048 +file1001823,file1000778,0.352076809462453 +file1000806,file1001130,0.352659078122633 +file1000365,file1000351,0.352772816610486 +file1002374,file1001778,0.352974481603711 +file1002516,file1001910,0.359896103026675 +file1001200,file1000931,0.360070003966827 +file1001479,file1000952,0.360424533696936 
+file1000850,file1001942,0.362632797518558 +file1001426,file1002143,0.363271909822866 +file1001304,file1001333,0.36404737582222 +file1000390,file1000518,0.364744579516818 +file1000830,file1002096,0.365897427529429 +file1000794,file1001856,0.365973692948894 +file1001266,file1001327,0.366395851089761 +file1001692,file1002352,0.36655953875445 +file1001564,file1001024,0.367284385415205 +file1001861,file1002050,0.36783497787384 +file1002066,file1002361,0.367964419694875 +file1001613,file1002087,0.368231014746024 +file1001931,file1000220,0.368847112914793 +file1000339,file1000554,0.370123905662701 +file1000754,file1002208,0.37031588493778 +file1001067,file1001956,0.371313060558732 +file1000101,file1001053,0.372141932838775 +file1002520,file1002409,0.372501194473693 +file1001459,file1001615,0.373295536945146 +file1001673,file1000508,0.376416667681519 +file1002201,file1001228,0.376680033570078 +file1000058,file1002449,0.376927627737029 +file1001748,file1001042,0.378067114701689 +file1001941,file1000376,0.37841176147662 +file1000801,file1002545,0.378423759459738 +file1000010,file1000535,0.38111194591455 +file1000882,file1002154,0.382223600234592 +file1001694,file1001297,0.382545161354354 +file1001992,file1002456,0.382664563820782 +file1001666,file1001773,0.382892588770697 +file1001629,file1002514,0.383417073960824 +file1002113,file1000738,0.385439884728523 +file1002221,file1000569,0.385903801966773 +file1002296,file1002117,0.387319754665673 +file1000693,file1001945,0.387855926202209 +file1001410,file1000223,0.391284037867147 +file1002071,file1001425,0.391497653794399 +file1002325,file1001259,0.391913965917762 +file1002430,file1001969,0.392256443856501 +file1002462,file1000708,0.393161981208355 +file1002358,file1001888,0.39427809496515 +file1000485,file1000753,0.395316199436001 +file1002357,file1001973,0.39564210237905 +file1002130,file1002041,0.395978941103639 +file1002569,file1000097,0.397496127623486 +file1002264,file1000148,0.397630184088734 +file1002381,file1001401,0.398105992102355 +file1000289,file1000585,0.399527637723015 +file1002368,file1001723,0.400243022234875 +file1002342,file1001319,0.400431803928825 +file1002170,file1001226,0.400632448147846 +file1001385,file1001758,0.400855988878681 +file1001732,file1002541,0.40091828863264 +file1001102,file1000762,0.400923140595936 +file1001470,file1000181,0.401353492516182 +file1000400,file1000884,0.401562860630016 +file1002293,file1002523,0.401800994807451 +file1000728,file1001654,0.402763341041675 +file1000582,file1001491,0.403451830806034 +file1000586,file1001521,0.403648293267187 +file1002287,file1001770,0.405194821414496 +file1000371,file1000159,0.405999000381268 +file1002356,file1002064,0.406519210876811 +file1000324,file1000590,0.407593694425997 +file1001622,file1001710,0.40759525378577 +file1002037,file1000403,0.407814136488744 +file1002444,file1000743,0.40943197761463 +file1001175,file1002088,0.410423663035312 +file1001391,file1000540,0.410854355646853 +file1002133,file1001186,0.411248429534111 +file1001229,file1001630,0.411355571792039 +file1002283,file1000402,0.411836769927671 +file1000627,file1000161,0.412089060388579 +file1001701,file1001402,0.412854774524637 +file1000795,file1000452,0.413448916432685 +file1000354,file1000947,0.41459642292987 +file1002043,file1002505,0.414863932355455 +file1001285,file1001113,0.418183757940871 +file1000170,file1001832,0.419441549204313 +file1002399,file1001500,0.419905873946513 +file1002439,file1000177,0.42054051043224 +file1001656,file1001217,0.420597020703942 
+file1000296,file1000065,0.420845042251081 +file1000626,file1001623,0.42087934790355 +file1001767,file1000760,0.422315537515139 +file1000467,file1001246,0.422371268999111 +file1001033,file1000611,0.42425275873442 +file1002304,file1000221,0.425602179771197 +file1001737,file1001141,0.425716789218234 +file1001565,file1000559,0.426158561043574 +file1000249,file1000643,0.426541100077021 +file1002014,file1001109,0.426587840438723 +file1002006,file1000790,0.427829459781438 +file1000193,file1000750,0.428103808477214 +file1001993,file1001110,0.428186367615143 +file1002094,file1001814,0.428868578868176 +file1000098,file1001420,0.428968675677784 +file1000336,file1000211,0.430347427208789 +file1001498,file1002568,0.43204475404071 +file1001671,file1001106,0.432215802861284 +file1000426,file1002386,0.43283446816702 +file1001520,file1002481,0.434867670495723 +file1002189,file1001432,0.434924370194975 +file1001390,file1002554,0.435313848731387 +file1002166,file1001982,0.435387512979012 +file1001120,file1001006,0.435594761785839 +file1000149,file1001985,0.436289528591294 +file1001632,file1001008,0.436682374331417 +file1002567,file1001155,0.437221000601772 +file1000434,file1002195,0.438098100114814 +file1002532,file1001048,0.438500899539101 +file1001605,file1000927,0.438686659342641 +file1000479,file1000120,0.439587267995034 +file1002473,file1001388,0.439594997597548 +file1001108,file1002228,0.440528754793898 +file1002099,file1002056,0.440776843467602 +file1000191,file1002127,0.441114509542672 +file1000875,file1002494,0.441378135507993 +file1002161,file1000002,0.441912476744187 +file1002269,file1001220,0.442742296865228 +file1001295,file1001355,0.4435162405589 +file1001659,file1001023,0.444686151316673 +file1001857,file1001378,0.447500830900898 +file1001183,file1001370,0.447782748040587 +file1000428,file1000859,0.448328910257083 +file1000588,file1002227,0.448650488897259 +file1001098,file1000486,0.448862467740607 +file1001288,file1000408,0.450363676957042 +file1002097,file1001210,0.451126832474666 +file1000216,file1001082,0.451550143520946 +file1001746,file1001642,0.451781042569196 +file1002388,file1000204,0.451940333555972 +file1000021,file1000560,0.452234621797968 +file1000489,file1001545,0.452796032302523 +file1001116,file1000883,0.453096911915119 +file1001372,file1000561,0.45532542913335 +file1001276,file1000424,0.45534174289324 +file1000974,file1002098,0.455371894001872 +file1002566,file1002044,0.455937677517583 +file1000262,file1002046,0.456056330767294 +file1001619,file1001342,0.456559091350965 +file1000045,file1001616,0.457599407743834 +file1001468,file1002115,0.458095965024278 +file1001061,file1000233,0.460561351667266 +file1000558,file1000100,0.461094222462111 +file1000605,file1000691,0.461429521647285 +file1000640,file1000384,0.463383466503099 +file1000410,file1001358,0.463452482427773 +file1000851,file1001014,0.463558384057952 +file1001092,file1000138,0.463591264436099 +file1000061,file1002049,0.465778207162619 +file1001206,file1000983,0.466701211830884 +file1000256,file1000475,0.466865377968187 +file1002434,file1001387,0.467154181996099 +file1001036,file1000210,0.470404279499276 +file1001540,file1001860,0.472822271037545 +file1001244,file1001154,0.475076170733515 +file1000131,file1001526,0.475459563440874 +file1000180,file1002045,0.476814451110009 +file1001837,file1000637,0.478851985878026 +file1002425,file1001891,0.481451070031007 +file1001056,file1000682,0.482320170742015 +file1002276,file1000777,0.483452141843029 +file1001139,file1002544,0.487462418948035 
+file1000548,file1001257,0.488098081542811 +file1000188,file1001286,0.488423105111001 +file1001879,file1000999,0.488449105381724 +file1001062,file1000231,0.48930683373911 +file1000040,file1001873,0.492070802214623 +file1002286,file1000066,0.493213986773381 +file1002474,file1002563,0.501584439120211 +file1000967,file1000563,0.502066261411662 +file1001307,file1002048,0.50460435259807 +file1000483,file1001699,0.511819026566198 +file1001528,file1000285,0.512629017841038 +file1001742,file1002371,0.513805213204644 +file1002397,file1000592,0.515406473057 +file1000069,file1000510,0.528220553613126 +file1001087,file1001300,0.536510449049583 +file1001991,file1000836,0.538145797125916 +file1001382,file1001806,0.538539506621535 +file1000111,file1001189,0.557690760784602 diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/knee_data_split/singlecoil_val_split_less.csv b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/knee_data_split/singlecoil_val_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..a1cbac5537562063359f4ac3e0985de51cb989b2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/knee_data_split/singlecoil_val_split_less.csv @@ -0,0 +1,45 @@ +file1000323,file1002538,0.30754967523156 +file1001458,file1001566,0.310512744537048 +file1000885,file1001059,0.318226346221521 +file1000464,file1000196,0.321465466968232 +file1000314,file1000178,0.327505552363568 +file1001163,file1001289,0.328954963947692 +file1000033,file1001191,0.330925609207301 +file1000976,file1000990,0.344036229323198 +file1001930,file1001834,0.345994076497818 +file1002546,file1001344,0.351762252794677 +file1000277,file1001429,0.353297786572139 +file1001893,file1001262,0.358064285890878 +file1000926,file1002067,0.360639004205491 +file1001650,file1002002,0.362186928073579 +file1001184,file1001655,0.362592305723707 +file1001497,file1001338,0.365599407221502 +file1001202,file1001365,0.3844323497275 +file1001126,file1002340,0.388929627976346 +file1001339,file1000291,0.391300537691403 +file1002187,file1001862,0.39883786878841 +file1000041,file1000591,0.39896683485823 +file1001064,file1001850,0.399687813966601 +file1001331,file1002214,0.400340820924839 +file1000831,file1000528,0.403582747590964 +file1000769,file1000538,0.405298051020298 +file1000182,file1001968,0.407646172205036 +file1002382,file1001651,0.410749052045234 +file1000660,file1000476,0.415423894745454 +file1002570,file1001726,0.424622351472032 +file1001585,file1000858,0.426738511964108 +file1000190,file1000593,0.428080574167047 +file1001170,file1001090,0.429987089825525 +file1002252,file1001440,0.432038842370013 +file1000697,file1001144,0.432558506761396 +file1001077,file1000000,0.441922503777368 +file1001381,file1001119,0.455418270809002 +file1001759,file1001851,0.460824505737749 +file1000635,file1002389,0.465674267492171 +file1001668,file1001689,0.467330511330772 +file1001221,file1000818,0.469630000354232 +file1001298,file1002145,0.473526387887779 +file1001763,file1001938,0.47398893150184 +file1001444,file1000942,0.48507438696692 +file1000735,file1002007,0.496530240691134 +file1000477,file1000280,0.528508000547834 diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/metric.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..53ddb27a96bab67975beef06ca6819e628208153 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/metric.py @@ -0,0 +1,51 @@ + +import numpy as np +from 
skimage.metrics import peak_signal_noise_ratio, structural_similarity + +def nmse(gt, pred): + """Compute Normalized Mean Squared Error (NMSE)""" + return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 + + +def psnr(gt, pred): + """Compute Peak Signal to Noise Ratio metric (PSNR)""" + return peak_signal_noise_ratio(gt, pred, data_range=gt.max()) + + +def ssim(gt, pred, maxval=None): + """Compute Structural Similarity Index Metric (SSIM)""" + maxval = gt.max() if maxval is None else maxval + + ssim = 0 + for slice_num in range(gt.shape[0]): + ssim = ssim + structural_similarity( + gt[slice_num], pred[slice_num], data_range=maxval + ) + + ssim = ssim / gt.shape[0] + + return ssim + + +class AverageMeter(object): + """Computes and stores the average and current value. + + Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 + """ + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.score = [] + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + self.score.append(val) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/__init__.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/common_freq.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..79cf3e778029a846b4da910c115c8315bf33dbaf --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/common_freq.py @@ -0,0 +1,389 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
+ for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs): + + out = self.layers(inputs) + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features): + super(ResBlock, self).__init__() + self.layers = nn.Sequential( + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1), + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + ) + + def forward(self, inputs): + return F.relu(self.layers(inputs) + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [ + ResBlock(n_feat) for _ in range(n_resblocks)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, 
norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x): + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels, args): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + 
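# --- Illustrative aside (not part of the original file) ----------------------
# FreBlock9 above operates in the frequency domain: it takes the 2-D rFFT of
# the feature map, refines amplitude and phase with separate conv branches,
# then rebuilds a complex spectrum and inverse-transforms it; FuseBlock7 then
# merges the spatial and frequency branches with cross-attention and a sigmoid
# gate. The sketch below (assumed shapes, no learned layers) shows that the
# amplitude/phase round trip alone is an identity up to floating-point error,
# so the learned residuals in amp_fuse / pha_fuse are what change the output.
def _freq_roundtrip_demo():
    x = torch.randn(1, 8, 32, 32)                     # assumed (B, C, H, W)
    spec = torch.fft.rfft2(x, norm='backward')
    amp, pha = torch.abs(spec), torch.angle(spec)
    rebuilt = torch.complex(amp * torch.cos(pha), amp * torch.sin(pha))
    x_rec = torch.fft.irfft2(rebuilt, s=x.shape[-2:], norm='backward')
    return torch.allclose(x, x_rec, atol=1e-4)        # True: lossless round trip
# ------------------------------------------------------------------------------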
+class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, 
posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ARTUnet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae23fed16b0d9454a5ea09f17b4322f1e9d7e928 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ARTUnet.py @@ -0,0 +1,751 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + # print("x shape:", x.shape) + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +########################################################################## +class CS_TransformerBlock(nn.Module): + r""" 在ART Transformer Block的基础上添加了channel attention的分支。 + 参考论文: Dual Aggregation Transformer for Image Super-Resolution + 及其代码: https://github.com/zhengchen1999/DAT/blob/main/basicsr/archs/dat_arch.py + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.InstanceNorm2d(dim), + nn.GELU()) + + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + # nn.InstanceNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1)) + + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.InstanceNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # print("x before dwconv:", x.shape) + # convolution output + conv_x = self.dwconv(x.permute(0, 3, 1, 2)) + # conv_x = x.permute(0, 3, 1, 2) + # print("x after dwconv:", x.shape) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + attened_x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + attened_x = attened_x[:, :H, :W, :].contiguous() + + attened_x = attened_x.view(B, H * W, C) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C) + # S-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W) + spatial_map = self.spatial_interaction(attention_reshape) + + # C-I + attened_x = attened_x * torch.sigmoid(channel_map) + # S-I + conv_x = torch.sigmoid(spatial_map) * conv_x + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + + x = attened_x + conv_x + + # 
FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ARTUnet_new.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ARTUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d210d46e2a4d73781b3b16b63a53fdc0f25b9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ARTUnet_new.py @@ -0,0 +1,823 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +Unet的格式构建image restoration 网络。每个transformer block中都是dense self-attention和sparse self-attention交替连接。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True, visualize_attention_maps=False): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + self.visualize_attention_maps = visualize_attention_maps + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + # assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + # qkv: 3, B_, num_heads, N, C // num_heads + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + # B_, num_heads, N, C // num_heads * + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + attn_map = attn.detach().cpu().numpy().mean(axis=1) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) # (B * nP, nHead, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + 
# attn: B * nP, num_heads, N, N + # v: B * nP, num_heads, N, C // num_heads + # -attn-@-v-> B * nP, num_heads, N, C // num_heads + # -transpose-> B * nP, N, num_heads, C // num_heads + # -reshape-> B * nP, N, C + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, G ** 2, G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, Gh * Gw, Gh * Gw) + else: + x = self.attn(x, Gh, Gw, mask=attn_mask) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +class ConcatTransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. 
+ num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + ################## + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // 
self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + ################## + + x = torch.cat((x, y), dim=1) + + shortcut = x + x = self.norm1(x) + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, 2 * Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, 2 * G ** 2, 2 * G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, 2 * Gh * Gw, 2 * Gh * Gw) + else: + x = self.attn(x, 2 * Gh, Gw, mask=attn_mask) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + output = x + + # print(x.shape, Gh, Gw) + x = output[:, :Gh * Gw, :] + y = output[:, Gh * Gw:, :] + assert x.shape == y.shape + + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = y.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = y.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, y, attn_map + else: + return x, y + + def get_patches(self, x, x_size): + H, W = x_size + B, H, W, C = x.shape + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, 
Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + return x, attn_mask, (Hd, Wd), (Gh, Gw), (pad_r, pad_b), G, I + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ART_Restormer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ART_Restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c23123f0ac1a8e82e2bf907658ce8d91951c3e4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ART_Restormer.py @@ -0,0 +1,592 @@ +""" +2023/07/24, Xiaohan Xing +U-shaped structure; each block is composed of an ART block and a Restormer block connected in series. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[4, 4, 4, 4], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + + self.down1_2 = Downsample(dim) ## 
From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + 
self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + + ######################################### Encoder level1 ############################################ + ### ART encoder block + for layer in self.ART_encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + + ### Restormer encoder block + out_enc_level1 = rearrange(out_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = self.Restormer_encoder_level1(out_enc_level1) + # stack.append(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + ######################################### Encoder level2 ############################################ + ### ART encoder block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.ART_encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + out_enc_level2 = rearrange(out_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = self.Restormer_encoder_level2(out_enc_level2) + # stack.append(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.ART_encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + out_enc_level3 = rearrange(out_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = self.Restormer_encoder_level3(out_enc_level3) + # stack.append(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 
############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.ART_latent: + latent = layer(latent, [H // 8, W // 8]) + + ### Restormer encoder block + latent = rearrange(latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = self.Restormer_latent(latent) + # stack.append(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + out_dec_level3 = inp_dec_level3 + for layer in self.ART_decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + out_dec_level3 = rearrange(out_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + out_dec_level3 = self.Restormer_decoder_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + out_dec_level2 = inp_dec_level2 + for layer in self.ART_decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + out_dec_level2 = rearrange(out_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + out_dec_level2 = self.Restormer_decoder_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + out_dec_level1 = inp_dec_level1 + for layer in self.ART_decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + ### Restormer decoder block + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_decoder_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_refinement(out_dec_level1) + + out_dec_level1 = 
self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ART_Restormer_v2.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ART_Restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..de8f921e81d79cb7af14a75326f0ddaaf3b3f4c3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ART_Restormer_v2.py @@ -0,0 +1,663 @@ +""" +2023/07/24, Xiaohan Xing +U-shaped structure; each block is composed of an ART branch and a Restormer branch connected in parallel. +The input features are fed through the ART and Restormer branches separately to extract features; the two feature sets are concatenated and then transformed by a conv layer to produce the block's output features. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.enc_fuse_level1 = nn.Conv2d(int(dim * 
2 ** 1), dim, kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.enc_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.enc_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + self.enc_fuse_level4 = nn.Conv2d(int(dim * 2 ** 4), int(dim * 2 ** 3), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.dec_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], 
window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.dec_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.dec_fuse_level1 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + # self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + # bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + # self.fuse_refinement = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + # def forward(self, inp_img, aux_img): + # stack = [] + # bs, _, H, W = inp_img.shape + + # fuse_img = torch.cat((inp_img, aux_img), 1) + # inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + + def forward(self, inp_img): + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + + ######################################### Encoder level1 ############################################ + art_enc_level1 = inp_enc_level1 + restor_enc_level1 = inp_enc_level1 + + ### ART encoder block + for layer in self.ART_encoder_level1: + art_enc_level1 = layer(art_enc_level1, [H, W]) + + ### Restormer encoder block + restor_enc_level1 = rearrange(restor_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_enc_level1 = self.Restormer_encoder_level1(restor_enc_level1) ### (b, c, h, w) + + art_enc_level1 = rearrange(art_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = torch.cat((art_enc_level1, restor_enc_level1), 1) + out_enc_level1 = self.enc_fuse_level1(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + + ######################################### Encoder level2 ############################################ + ### ART encoder 
block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + art_enc_level2 = inp_enc_level2 + restor_enc_level2 = inp_enc_level2 + + for layer in self.ART_encoder_level2: + art_enc_level2 = layer(art_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + restor_enc_level2 = rearrange(restor_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + restor_enc_level2 = self.Restormer_encoder_level2(restor_enc_level2) + + art_enc_level2 = rearrange(art_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = torch.cat((art_enc_level2, restor_enc_level2), 1) + out_enc_level2 = self.enc_fuse_level2(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + art_enc_level3 = inp_enc_level3 + restor_enc_level3 = inp_enc_level3 + + for layer in self.ART_encoder_level3: + art_enc_level3 = layer(art_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + restor_enc_level3 = rearrange(restor_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + restor_enc_level3 = self.Restormer_encoder_level3(restor_enc_level3) + + art_enc_level3 = rearrange(art_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = torch.cat((art_enc_level3, restor_enc_level3), 1) + out_enc_level3 = self.enc_fuse_level3(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 ############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + art_latent = inp_enc_level4 + restor_latent = inp_enc_level4 + + for layer in self.ART_latent: + art_latent = layer(art_latent, [H // 8, W // 8]) + + ### Restormer encoder block + restor_latent = rearrange(restor_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + restor_latent = self.Restormer_latent(restor_latent) + + art_latent = rearrange(art_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = torch.cat((art_latent, restor_latent), 1) + latent = self.enc_fuse_level4(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + art_dec_level3 = inp_dec_level3 + restor_dec_level3 = inp_dec_level3 + + for layer in self.ART_decoder_level3: + art_dec_level3 = layer(art_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + restor_dec_level3 = rearrange(restor_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + restor_dec_level3 = 
self.Restormer_decoder_level3(restor_dec_level3) + + art_dec_level3 = rearrange(art_dec_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_dec_level3 = torch.cat((art_dec_level3, restor_dec_level3), 1) + out_dec_level3 = self.dec_fuse_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + art_dec_level2 = inp_dec_level2 + restor_dec_level2 = inp_dec_level2 + + for layer in self.ART_decoder_level2: + art_dec_level2 = layer(art_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + restor_dec_level2 = rearrange(restor_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + restor_dec_level2 = self.Restormer_decoder_level2(restor_dec_level2) + + art_dec_level2 = rearrange(art_dec_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_dec_level2 = torch.cat((art_dec_level2, restor_dec_level2), 1) + out_dec_level2 = self.dec_fuse_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + art_dec_level1 = inp_dec_level1 + restor_dec_level1 = inp_dec_level1 + + for layer in self.ART_decoder_level1: + art_dec_level1 = layer(art_dec_level1, [H, W]) + + ### Restormer decoder block + restor_dec_level1 = rearrange(restor_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_dec_level1 = self.Restormer_decoder_level1(restor_dec_level1) + + art_dec_level1 = rearrange(art_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = torch.cat((art_dec_level1, restor_dec_level1), 1) + out_dec_level1 = self.dec_fuse_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +# def build_model(args): +# return ART_Restormer( +# inp_channels=2, +# out_channels=1, +# dim=32, +# num_art_blocks=[2, 4, 4, 6], +# num_restormer_blocks=[2, 2, 2, 2], +# num_refinement_blocks=4, +# heads=[1, 2, 4, 8], +# window_size=[10, 10, 10, 10], mlp_ratio=4., +# qkv_bias=True, qk_scale=None, +# interval=[24, 12, 6, 3], +# ffn_expansion_factor = 2.66, +# LayerNorm_type = 'WithBias', ## Other option 'BiasFree' +# bias=False) + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 
2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ARTfuse_layer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ARTfuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..78af0b27528b50db897936f6b8b4eee51d566b1c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/ARTfuse_layer.py @@ -0,0 +1,705 @@ +""" +ART layer in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input to the Attention layer:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + # print("q range:", q.max(), q.min()) + # print("k range:", k.max(), k.min()) + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + # print("attention matrix:", attn.shape, attn) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + # print("feature before proj-layer:", x.max(), x.min()) + x = self.proj(x) + # print("feature after proj-layer:", x.max(), x.min()) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. 
+ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pre_norm=True): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + # self.conv = nn.Conv2d(dim, dim, kernel_size=1) + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + # self.conv = ConvBlock(self.dim, self.dim, drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.bn_norm = nn.BatchNorm2d(dim) + self.pre_norm = pre_norm + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + # x = self.conv(x) + shortcut = x + if self.pre_norm: + # print("using pre_norm") + x = self.norm1(x) ## 归一化之后特征范围变得非常小 + x = x.view(B, H, W, C) + # print("normalized input range:", x.max(), x.min(), x.mean()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 
1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + # print("range of feature before layer:", x.max(), x.min(), x.mean()) + # x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + # x_bn = self.bn_norm(x.permute(0, 3, 1, 2)) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + # print("[ART layer] input and output feature difference:", torch.mean(torch.abs(shortcut - x))) + # print("range of feature before ART layer:", shortcut.max(), shortcut.min(), shortcut.mean()) + # print("range of feature after ART layer:", x.max(), x.min(), x.mean()) + if self.pre_norm: + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + # print("using post_norm") + x = self.norm1(shortcut + self.drop_path(x)) + x = self.norm2(x + self.drop_path(self.mlp(x))) + # x = shortcut + self.drop_path(x) + # x = x + self.drop_path(self.mlp(x)) + + # print("range of fused feature:", x.max(), x.min()) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + + + +########################################################################## +class Cross_TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +class Cross_TransformerBlock_v2(nn.Module): + r""" ART Transformer Block. + 将两个模态的特征沿着channel方向concat, 一起取window. 之后把取出来的两个模态中同一个window的所有patches合并,送到transformer处理。 + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + xy = torch.cat((x, y), -1) + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + xy = F.pad(xy, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + xy = xy.reshape(B, Hd // G, G, Wd // G, G, 2*C).permute(0, 1, 3, 2, 4, 5).contiguous() + xy = xy.reshape(B * Hd * Wd // G ** 2, G ** 2, 2*C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + xy = xy.reshape(B, Gh, I, Gw, I, 2*C).permute(0, 2, 4, 1, 3, 5).contiguous() + xy = xy.reshape(B * I * I, Gh * Gw, 2*C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ x = xy[:, :, :C] + y = xy[:, :, C:] + xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DataConsistency.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DataConsistency.py new file mode 100644 index 0000000000000000000000000000000000000000..5f460e9acabcb33e1a3f812eb9acaeb337da446c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DataConsistency.py @@ -0,0 +1,48 @@ +""" +Created: DataConsistency @ Xiyang Cai, 2023/09/09 + +Data consistency layer for k-space signal. 
+ +Ref: DataConsistency in DuDoRNet (https://github.com/bbbbbbzhou/DuDoRNet) + +""" + +import torch +from torch import nn +from einops import repeat + + +def data_consistency(k, k0, mask): + """ + k - input in k-space + k0 - initially sampled elements in k-space + mask - corresponding nonzero location + """ + out = (1 - mask) * k + mask * k0 + return out + + +class DataConsistency(nn.Module): + """ + Create data consistency operator + """ + + def __init__(self): + super(DataConsistency, self).__init__() + + def forward(self, k, k0, mask): + """ + k - input in frequency domain, of shape (n, nx, ny, 2) + k0 - initially sampled elements in k-space + mask - corresponding nonzero location (n, 1, len, 1) + """ + + if k.dim() != 4: # input is 2D + raise ValueError("error in data consistency layer!") + + # mask = repeat(mask.squeeze(1, 3), 'b x -> b x y c', y=k.shape[1], c=2) + mask = torch.tile(mask, (1, mask.shape[2], 1, k.shape[-1])) ### [n, 320, 320, 2] + # print("k and k0 shape:", k.shape, k0.shape) + out = data_consistency(k, k0, mask) + + return out, mask diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DuDo_mUnet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DuDo_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a56069a27f0cdd392482b41dc4e3b79252295487 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DuDo_mUnet.py @@ -0,0 +1,359 @@ +""" +Xiaohan Xing, 2023/10/25 +传入两个模态的kspace data, 经过kspace network进行重建。 +然后把重建之后的kspace data变换回图像,然后送到image domain的网络进行图像重建。 +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + +from .Kspace_mUnet import kspace_mmUnet + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = kspace_ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = kspace_ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = kspace_ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = kspace_ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. 
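# Editorial sketch, not part of the patch: the comment above relies on the
# convolution theorem -- element-wise multiplication in the frequency domain
# equals circular convolution with a full-image-size kernel in the image
# domain. A tiny self-contained check (torch is already imported by this
# module; it is repeated here so the snippet stands alone, and the tensor
# names are illustrative only):
import torch

_x, _h = torch.randn(8, 8), torch.randn(8, 8)
_via_freq = torch.fft.ifft2(torch.fft.fft2(_x) * torch.fft.fft2(_h)).real

_via_img = torch.zeros(8, 8)
for _u in range(8):
    for _v in range(8):
        _s = 0.0
        for _i in range(8):
            for _j in range(8):
                # circular convolution: sum of x[i, j] * h[(u - i) % N, (v - j) % N]
                _s = _s + _x[_i, _j] * _h[(_u - _i) % 8, (_v - _j) % 8]
        _via_img[_u, _v] = _s

assert torch.allclose(_via_freq, _via_img, atol=1e-3)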
+ spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class DuDo_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 3, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + # self.kspace_net = mmConvKSpace() + self.kspace_net = kspace_mmUnet(args) + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + stack = [] + feature_stack = [] + # output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ output = torch.cat((image, LR_image, aux_image), 1) #.detach() + + # print("LR image range:", LR_image.max(), LR_image.min()) + # print("kspace_recon image range:", image.max(), image.min()) + # print("aux_image range:", aux_image.max(), aux_image.min()) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output, image, recon_kspace, masked_kspace, data_stats + + + +class kspace_ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return DuDo_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6df67b63140ca51bd7b8664151ab4b1b7a6305d5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py @@ -0,0 +1,369 @@ +""" +Xiaohan Xing, 2023/11/08 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + ARTfusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output 
+
+            # reflect pad on the right/bottom if needed to handle odd input dimensions
+            padding = [0, 0, 0, 0]
+            if output.shape[-1] != downsample_layer.shape[-1]:
+                padding[1] = 1  # padding right
+            if output.shape[-2] != downsample_layer.shape[-2]:
+                padding[3] = 1  # padding bottom
+            if torch.sum(torch.tensor(padding)) != 0:
+                output = F.pad(output, padding, "reflect")
+
+            output = torch.cat([output, downsample_layer], dim=1)
+            output = conv(output)
+
+        return output
+
+
+class mUnet_ARTfusion(nn.Module):
+    """
+    The overall framework is a multi-modal UNet. The two modalities extract features
+    at every layer separately, and the per-layer features are then fused by ART blocks.
+    """
+
+    def __init__(
+            self, args,
+            input_dim: int = 1,
+            output_dim: int = 1,
+            chans: int = 32,
+            num_pool_layers: int = 4,
+            drop_prob: float = 0.0,
+            num_blocks=[2, 4, 4, 6],
+            heads=[1, 2, 4, 8],
+            window_size=[10, 10, 10, 10],
+            interval=[24, 12, 6, 3],
+            mlp_ratio=4.):
+
+        super().__init__()
+
+        self.in_chans = input_dim
+        self.out_chans = output_dim
+        self.chans = chans
+        self.num_pool_layers = num_pool_layers
+        self.drop_prob = drop_prob
+
+        self.kspace_net = kspace_mmUnet(args)
+
+        self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob)
+        self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob)
+        ch = self.encoder1.ch
+
+        self.fuse_layers = nn.ModuleList()
+        for l in range(self.num_pool_layers):
+            self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l],
+                window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1,
+                mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0.,
+                attn_drop=0., drop_path=0.) for i in range(num_blocks[l])]))
+
+        # print(self.fuse_layers)
+
+        self.conv1 = ConvBlock(ch, ch * 2, drop_prob)
+        self.conv2 = ConvBlock(ch, ch * 2, drop_prob)
+
+        self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob)
+        self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob)
+
+    def complex_abs(self, data):
+        """
+        data: [B, 2, H, W]
+        """
+        if data.size(1) == 2:
+            return (data ** 2).sum(dim=1).sqrt()
+        elif data.size(-1) == 2:
+            return (data ** 2).sum(dim=-1).sqrt()
+
+    def normalize(self, data, mean, stddev, eps=0.0):
+        """
+        Normalize the given tensor.
+        Applies the formula (data - mean) / (stddev + eps).
+        """
+        return (data - mean) / (stddev + eps)
+
+    def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \
+                aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            args: Experiment configuration; `args.domain` selects the "dual" or "image" branch.
+            kspace: Under-sampled target (T2) k-space of shape `(N, 2, H, W)`.
+            ref_kspace: Fully-sampled reference (T1) k-space of shape `(N, 2, H, W)`.
+            mask: Sampling mask forwarded to the k-space network.
+            aux_image: Auxiliary (T1) image of shape `(N, 1, H, W)`.
+            t1_max, t2_max: Per-sample maxima used to rescale the two modalities.
+
+        Returns:
+            Tuple of `(recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats)`.
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + output1, output2 = aux_image.detach(), LR_image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+            output1 = net1_layer(output1)
+            output2 = net2_layer(output2)
+
+            stack1.append(output1)
+            stack2.append(output2)
+
+            # print("range of Unet feature:", output1.max(), output1.min())
+            # print("range of Unet feature:", output2.max(), output2.min())
+
+            ### Concatenate the features of the two modalities.
+            output = torch.cat((output1, output2), 1)
+            output = rearrange(output, "b c h w -> b (h w) c").contiguous()
+
+            for layer in fuse_layer:
+                output = layer(output, [H // (2**l), W // (2**l)])
+
+            output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous()
+
+            output1 = output[:, :output.shape[1]//2, :, :]
+            output2 = output[:, output.shape[1]//2:, :, :]
+
+            ### downsampling features
+            output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0)
+            output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0)
+
+            l += 1
+
+            t1_features.append(output1)
+            t2_features.append(output2)
+
+            # print("range of output1:", output1.max(), output1.min())
+            # print("range of output2:", output2.max(), output2.min())
+
+        # output = torch.cat((output1, output2), 1)
+        output1 = self.conv1(output1)
+        output2 = self.conv2(output2)
+        recon_t1 = self.decoder1(output1, stack1)
+        recon_t2 = self.decoder2(output2, stack2)
+
+        return recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats
+
+
+def get_relation_matrix(feature):
+    """
+    Resize the features of each layer to 5*5, then compute the 25*25 relation matrix
+    between positions.
+    """
+    bs, c, h, w = feature.shape
+    feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15)
+    avg_pool = nn.AdaptiveAvgPool2d(5)
+    feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1)  ### (bs, 5*5, c)
+    # print("intermediate feature:", feature.shape)
+
+    feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True)  ### (bs, 5*5, 1)
+    relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1))  # (bs, 25, 25)
+
+    return relation_matrix
+
+
+class ConvBlock(nn.Module):
+    """
+    A Convolutional Block that consists of two convolution layers each followed by
+    instance normalization, LeakyReLU activation and dropout.
+    """
+
+    def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
+        """
+        Args:
+            in_chans: Number of channels in the input.
+            out_chans: Number of channels in the output.
+            drop_prob: Dropout probability.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.drop_prob = drop_prob
+
+        self.layers = nn.Sequential(
+            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Dropout2d(drop_prob),
+            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Dropout2d(drop_prob),
+        )
+
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns:
+            Output tensor of shape `(N, out_chans, H, W)`.
+        """
+        return self.layers(image)
+
+
+class TransposeConvBlock(nn.Module):
+    """
+    A Transpose Convolutional Block that consists of one transpose convolution
+    layer followed by instance normalization and LeakyReLU activation.
+    """
+
+    def __init__(self, in_chans: int, out_chans: int):
+        """
+        Args:
+            in_chans: Number of channels in the input.
+            out_chans: Number of channels in the output.
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c2b86b6e2031b1af07ab7cb58efc182b7a2673 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py @@ -0,0 +1,338 @@ +""" +Xiaohan Xing, 2023/11/10 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + multi layer concat fusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, 
padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_CATfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t2_gt: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + t2_image = t2_gt * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + # # print("aux_image range:", aux_image.max(), aux_image.min()) + # print("LR_image range:", LR_image.max(), LR_image.min()) + # print("kspace recon_img range:", image.max(), image.min()) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + t2_gt_image = self.normalize(t2_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + t2_gt_image = torch.clamp(t2_gt_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + # output1, output2 = aux_image.detach(), LR_image.detach() + output1, output2 = aux_image.detach(), image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image, recon_kspace, masked_kspace, data_stats + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_CATfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Kspace_ConvNet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Kspace_ConvNet.py new file mode 100644 index 0000000000000000000000000000000000000000..918cbe598a62653394c8c4eaf99cd10a45c064ca --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Kspace_ConvNet.py @@ -0,0 +1,140 @@ +""" +2023/10/05 +Xiaohan Xing +Build a simple network with several convolution layers. 
+Input: concat (undersampled kspace of FS-PD, fully-sampled kspace of PD modality) +Output: interpolated kspace of FS-PD +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, args, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + # self.conv_blocks = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # self.conv_blocks.append(ConvBlock(self.chans, self.chans * 2, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans * 2, self.chans, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans, self.out_chans, drop_prob)) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + # print("output of the three conv blocks:", output1.shape, output2.shape, output3.shape) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("conv1_output range:", output1.max(), output1.min()) + # print("conv2_output range:", output2.max(), output2.min()) + # print("conv3_output range:", output3.max(), output3.min()) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. + spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + # print("output before DC layer:", output.shape) + # print("output range before DC layer:", output.max(), output.min()) + # output = output + kspace ## residual connection + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + # mask_output = output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
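+    (Note: in this k-space variant the LeakyReLU activations are commented out below, so each
+    stage is effectively convolution + instance normalization + dropout.)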
+ """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +def build_model(args): + return mmConvKSpace(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Kspace_mUnet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Kspace_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2db9a357798be4abd4cfbd44e9207555770b0456 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Kspace_mUnet.py @@ -0,0 +1,202 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # print("output:", output.shape) + # print("kspace:", kspace.shape) + # print("mask:", mask.shape) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Kspace_mUnet_new.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Kspace_mUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ccfd28dc198ffeff2259a5fbed45dabdc76819 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Kspace_mUnet_new.py @@ -0,0 +1,200 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # print("using Unet without Relu layers") + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + # output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # # print("output:", output.shape) + # # print("kspace:", kspace.shape) + # # print("mask:", mask.shape) + # mask_output = mask * output + + return output # .permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/MINet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeeb063b12be1dc594e8e076e6a59e454fcfa0b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/MINet.py @@ -0,0 +1,356 @@ +""" +Compare with the method "Multi-contrast MRI Super-Resolution via a Multi-stage Integration Network". +The code is from https://github.com/chunmeifeng/MINet/edit/main/fastmri/models/MINet.py. 
+""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F + + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): 
+ def __init__(self,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 1 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + # print("modules_tail:", modules_tail) + # self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Sigmoid()) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + + + +class MINet(nn.Module): + def __init__(self, args, n_resgroups=3, n_resblocks=3, n_feats=64): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + # print("x1:", x1.shape, "x2:", x2.shape) + + ### 源代码中是super-resolution任务,所以self.tail对图像进行unsampling. 
我们做reconstruction把这步去掉 + x2 = self.tail(x2) + + + resT1 = x1 + resT2 = x2 + + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + # print("t1s:", [item.shape for item in t1s]) + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + # print("resT1 range:", resT1.max(), resT1.min()) + # print("resT2 range:", resT2.max(), resT2.min()) + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1, x2 #x1=pd x2=pdfs + + +def build_model(args): + return MINet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/MINet_common.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/MINet_common.py new file mode 100644 index 0000000000000000000000000000000000000000..631f3036db4edf8eab00d70c66d2058a3b65cd59 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/MINet_common.py @@ -0,0 +1,90 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
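+            # With scale=1 the loop below runs zero times (log2(1) = 0), so no upsampling
+            # stages are added, which is how the reconstruction branches in this repo use it;
+            # scale=2 would add one conv -> PixelShuffle(2) stage.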
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + # print("upsampler:", m) + + super(Upsampler, self).__init__(*m) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/SANet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/SANet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfac3590e248c9532b002885c6c0b7a55ab1400a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/SANet.py @@ -0,0 +1,392 @@ +""" +This script contains the codes of "Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution (TNNLS 2023)" +https://github.com/chunmeifeng/SANet/blob/main/fastmri/models/SANet.py +""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import os + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, 
height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,scale,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups #10 + self.n_resblocks = n_resblocks #20 + self.n_feats = n_feats #64 + kernel_size = 3 + reduction = 16 #16 + # scale = args.scale[0] + + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + 
print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' + .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class Seatt(nn.Module): + def __init__(self, in_c): + super(Seatt, self).__init__() + self.reduce = nn.Conv2d(in_c * 2, 32, 1) + self.ff_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.bf_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.rgbd_pred_layer = Pred_Layer(32 * 2) + self.convq = nn.Conv2d(32, 32, 3, 1, 1) + + def forward(self, rgb_feat, dep_feat, pred): + feat = torch.cat((rgb_feat, dep_feat), 1) + + # vis_attention_map_rgb_feat(rgb_feat[0][0]) + # vis_attention_map_dep_feat(dep_feat[0][0]) + # vis_attention_map_feat(feat[0][0]) + + + feat = self.reduce(feat) + [_, _, H, W] = feat.size() + pred = torch.sigmoid(pred) + + ni_pred = 1 - pred + + ff_feat = self.ff_conv(feat * pred) + bf_feat = self.bf_conv(feat * (1 - pred)) + new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) + + return new_pred + +class SANet(nn.Module): + def __init__(self, args, scale=1, n_resgroups=3, n_resblocks=3, n_feats=64): + super(SANet, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + + cs = [64,64,64,64] + self.Seatts = nn.ModuleList([Seatt(c) for c in cs]) + + self.net1 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + print("nlayer:",nlayer) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=3, padding=1) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2, Seatt, map_conv in zip(self.net1.body._modules.items(), self.net2.body._modules.items(), self.Seatts, 
self.map_convs): + + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + pred = map_conv(resT1) + res = Seatt(resT1,resT2,pred) + + resT2 = res+resT2 + + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + + return x1,x2 + + +def build_model(args): + return SANet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/SwinFuse_layer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/SwinFuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3dca8ba793d92961bc3f1d987e211fe542b303 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/SwinFuse_layer.py @@ -0,0 +1,1310 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. 
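[Editorial sketch] The `window_partition` / `window_reverse` pair defined earlier in SwinFuse_layer.py is an exact inverse pair: partition a (B, H, W, C) map into non-overlapping windows for attention, then stitch the windows back. A toy check of the shape contract, assuming H and W are divisible by the window size:

import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    # (num_windows*B, window_size, window_size, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

B, H, W, C, ws = 2, 8, 8, 3, 4
x = torch.randn(B, H, W, C)
wins = window_partition(x, ws)              # (2*4, 4, 4, 3): four windows per image
assert wins.shape == (B * (H // ws) * (W // ws), ws, ws, C)
assert torch.equal(window_reverse(wins, ws, H, W), x)   # lossless round trip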
+ num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # print("x shape:", x.shape) + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + # print("num heads:", self.num_heads) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * 
self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + 
relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + # print("x_windows:", x_windows.shape) + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. 
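[Editorial sketch] The `calculate_mask` routine above labels nine spatial regions and then forbids attention between tokens from different regions inside each shifted window. A self-contained re-derivation on a toy 8x8 grid (sizes chosen only for illustration):

import torch

def toy_shift_mask(H, W, window_size, shift_size):
    # same recipe as calculate_mask above, written out for a single toy case
    img_mask = torch.zeros(1, H, W, 1)
    region_slices = (slice(0, -window_size),
                     slice(-window_size, -shift_size),
                     slice(-shift_size, None))
    cnt = 0
    for h in region_slices:
        for w in region_slices:
            img_mask[:, h, w, :] = cnt       # label one of 9 regions
            cnt += 1
    # window_partition, inlined: (1, H, W, 1) -> (nW, window_size*window_size)
    win = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
    win = win.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
    diff = win.unsqueeze(1) - win.unsqueeze(2)
    return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)

mask = toy_shift_mask(H=8, W=8, window_size=4, shift_size=2)
print(mask.shape)                         # torch.Size([4, 16, 16]); one mask per window
print(mask.flatten(1).ne(0).any(dim=1))   # only windows straddling the shift are masked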
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
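[Editorial sketch] `PatchMerging` above concatenates the four pixels of every 2x2 neighbourhood along the channel axis and projects 4C down to 2C, so the token grid is halved while the channel width doubles. A toy walk-through of that reshaping (hypothetical sizes):

import torch
import torch.nn as nn

B, H, W, C = 2, 8, 8, 16
x = torch.randn(B, H * W, C).view(B, H, W, C)

# gather the four pixels of every 2x2 block and stack them channel-wise
x0, x1 = x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :]
x2, x3 = x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], dim=-1).view(B, -1, 4 * C)   # B, H/2*W/2, 4C

reduction = nn.Linear(4 * C, 2 * C, bias=False)    # mirrors PatchMerging.reduction
out = reduction(nn.LayerNorm(4 * C)(merged))
print(out.shape)    # torch.Size([2, 16, 32]): 64 tokens -> 16 tokens, C doubled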
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + + # # 不使用残差连接 + # x = self.residual_group_A(x, x_size) + # y = self.residual_group_B(y, x_size) + + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion_layer(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
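[Editorial sketch] `PatchEmbed` simply flattens the feature map into a token sequence (B, C, H, W -> B, H*W, C) and `PatchUnEmbed` restores the spatial layout; the pair is lossless, as this toy check shows:

import torch

B, C, H, W = 2, 96, 16, 16
feat = torch.randn(B, C, H, W)

tokens = feat.flatten(2).transpose(1, 2)              # PatchEmbed: B, H*W, C
restored = tokens.transpose(1, 2).view(B, C, H, W)    # PatchUnEmbed with x_size=(H, W)

assert tokens.shape == (B, H * W, C)
assert torch.equal(restored, feat)                    # exact round trip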
+ Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Fusion_num_heads=[6, 6], + window_size=7, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, img_range=1., resi_connection='1conv', + **kwargs): + super(SwinFusion_layer, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + 
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + + ### fusion layer. + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + self.norm_Fusion_B = norm_layer(self.num_features) + + # print("fusion layers:", len(self.layers_Fusion)) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + + def forward_features_Fusion(self, x, y): + + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + # x = torch.cat([x, y], 1) + # ## Downsample the feature in the channel dimension + # x = self.lrelu(self.conv_after_body_Fusion(x)) + # print("x range after fusion:", x.max(), x.min()) + # print("y 
range after fusion:", y.max(), y.min()) + # print("x:", x.shape) + + return x, y + + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = self.check_image_size(y) + + print("modality1 input:", x.max(), x.min()) + print("modality2 input:", y.max(), y.min()) + + # multi-modal feature fusion. + x, y = self.forward_features_Fusion(x, y) ### 返回两个模态融合之后的features. + print("modality1 output:", x.max(), x.min()) + print("modality2 output:", y.max(), y.min()) + + return x[:, :, :H, :W], y[:, :, :H, :W] + + + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + + +if __name__ == '__main__': + width, height = 120, 120 + swinfuse = SwinFusion_layer(img_size=(width, height), window_size=8, depths=[6, 6, 6], embed_dim=60, num_heads=[6, 6, 6]) + # print("SwinFusion layer:", swinfuse) + + A = torch.randn((4, 60, width, height)) + B = torch.randn((4, 60, width, height)) + + x = swinfuse(A, B) + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/SwinFusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/SwinFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3e72c2a3e8859423f3544043e8b9feaec002d0e2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/SwinFusion.py @@ -0,0 +1,1468 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # print("qkv after reshape:", qkv.shape) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. 
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
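A standalone sketch (names local to this example) of how calculate_mask above builds the SW-MSA attention mask, here for an 8x8 feature map with window_size 4 and shift_size 2:

import torch

def build_sw_msa_mask(H, W, window_size, shift_size):
    img_mask = torch.zeros((1, H, W, 1))
    h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # same split as window_partition: token pairs from different regions get -100, same region 0
    mask = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
    mask = mask.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
    attn_mask = mask.unsqueeze(1) - mask.unsqueeze(2)
    return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

print(build_sw_msa_mask(8, 8, 4, 2).shape)   # torch.Size([4, 16, 16])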
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
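A shape-level sketch of the cross-modal attention wiring defined above (assuming Cross_WindowAttention from this file; inputs are random): modality A's windows provide the queries against modality B's keys/values, and vice versa:

import torch

attn_A = Cross_WindowAttention(dim=32, window_size=(4, 4), num_heads=4)
attn_B = Cross_WindowAttention(dim=32, window_size=(4, 4), num_heads=4)
x_windows = torch.randn(8, 16, 32)    # nW*B, window_size*window_size, C
y_windows = torch.randn(8, 16, 32)
out_A = attn_A(x_windows, y_windows)  # queries from x, keys/values from y
out_B = attn_B(y_windows, x_windows)  # queries from y, keys/values from x
print(out_A.shape, out_B.shape)       # torch.Size([8, 16, 32]) twice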
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + # 不使用残差连接 + # print("input x shape:", x.shape) + x = self.residual_group_A(x, x_size) + y = self.residual_group_B(y, x_size) + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
+ + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Re_depths=[4], + Ex_num_heads=[6], Fusion_num_heads=[6, 6], Re_num_heads=[6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinFusion, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + print('in_chans: ', in_chans) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + ####修改shallow feature extraction 网络, 修改为2个3x3的卷积#### + self.conv_first1_A = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first1_B = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first2_A = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1) + self.conv_first2_B = nn.Conv2d(embed_dim_temp, embed_dim_temp, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + self.Re_num_layers = len(Re_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = 
patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + dpr_Re = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Re_depths))] # stochastic depth decay rule + # build Residual Swin Transformer blocks (RSTB) + self.layers_Ex_A = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_A.append(layer) + self.norm_Ex_A = norm_layer(self.num_features) + + self.layers_Ex_B = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_B.append(layer) + self.norm_Ex_B = norm_layer(self.num_features) + + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + 
self.norm_Fusion_B = norm_layer(self.num_features) + + self.layers_Re = nn.ModuleList() + for i_layer in range(self.Re_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Re_depths[i_layer], + num_heads=Re_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Re[sum(Re_depths[:i_layer]):sum(Re_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Re.append(layer) + self.norm_Re = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last1 = nn.Conv2d(embed_dim, embed_dim_temp, 3, 1, 1) + self.conv_last2 = nn.Conv2d(embed_dim_temp, int(embed_dim_temp/2), 3, 1, 1) + self.conv_last3 = nn.Conv2d(int(embed_dim_temp/2), num_out_ch, 3, 1, 1) + + self.tanh = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features_Ex_A(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_A: + x = layer(x, x_size) + + x = self.norm_Ex_A(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_Ex_B(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_B: + x = layer(x, x_size) + + x = self.norm_Ex_B(x) # B L C + x = self.patch_unembed(x, x_size) + return x + + def forward_features_Fusion(self, x, y): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + # print("x token:", x.shape, "y token:", y.shape) + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + x = torch.cat([x, y], 1) + ## Downsample the feature in the channel dimension + x = self.lrelu(self.conv_after_body_Fusion(x)) + + return x + + def forward_features_Re(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Re: + x = layer(x, x_size) + + x = self.norm_Re(x) # B L C + x = self.patch_unembed(x, x_size) + # Convolution + x = self.lrelu(self.conv_last1(x)) + x = self.lrelu(self.conv_last2(x)) + x = self.conv_last3(x) + return x + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = 
self.check_image_size(y) + + # self.mean_A = self.mean.type_as(x) + # self.mean_B = self.mean.type_as(y) + # self.mean = (self.mean_A + self.mean_B) / 2 + + # x = (x - self.mean_A) * self.img_range + # y = (y - self.mean_B) * self.img_range + + # Feedforward + x = self.forward_features_Ex_A(x) + y = self.forward_features_Ex_B(y) + # print("x before fusion:", x.shape) + # print("y before fusion:", y.shape) + x = self.forward_features_Fusion(x, y) + x = self.forward_features_Re(x) + x = self.tanh(x) + + # x = x / self.img_range + self.mean + # print("H:", H, "upscale:", self.upscale) + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='convs') + + +if __name__ == '__main__': + model = SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='convs') + # print(model) + + A = torch.randn((1, 1, 240, 240)) + B = torch.randn((1, 1, 240, 240)) + + x = model(A, B) + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/TransFuse.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/TransFuse.py new file mode 100644 index 0000000000000000000000000000000000000000..56ca3d2b24f382603c876e4f6909363a9794ffa3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/TransFuse.py @@ -0,0 +1,155 @@ +import math +from collections import deque + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torchvision import models + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. 
+ """ + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(n_embd, n_embd) + self.query = nn.Linear(n_embd, n_embd) + self.value = nn.Linear(n_embd, n_embd) + # regularization + self.attn_drop = nn.Dropout(attn_pdrop) + self.resid_drop = nn.Dropout(resid_pdrop) + # output projection + self.proj = nn.Linear(n_embd, n_embd) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop): + super().__init__() + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.mlp = nn.Sequential( + nn.Linear(n_embd, block_exp * n_embd), + nn.ReLU(True), # changed from GELU + nn.Linear(block_exp * n_embd, n_embd), + nn.Dropout(resid_pdrop), + ) + + def forward(self, x): + B, T, C = x.size() + + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + + return x + + +class TransFuse_layer(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, n_embd, n_head, block_exp, n_layer, + num_anchors, seq_len=1, + embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1): + super().__init__() + self.n_embd = n_embd + self.seq_len = seq_len + self.vert_anchors = num_anchors + self.horz_anchors = num_anchors + + # positional embedding parameter (learnable), image + lidar + self.pos_emb = nn.Parameter(torch.zeros(1, 2 * seq_len * self.vert_anchors * self.horz_anchors, n_embd)) + + self.drop = nn.Dropout(embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(n_embd, n_head, + block_exp, attn_pdrop, resid_pdrop) + for layer in range(n_layer)]) + + # decoder head + self.ln_f = nn.LayerNorm(n_embd) + + self.block_size = seq_len + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + + def forward(self, m1, m2): + """ + Args: + m1 (tensor): B*seq_len, C, H, W + m2 (tensor): B*seq_len, C, H, W + """ + + bz = m2.shape[0] // self.seq_len + h, w = m2.shape[2:4] + + # forward the image model for token embeddings + m1 = m1.view(bz, self.seq_len, -1, h, w) + m2 = m2.view(bz, self.seq_len, -1, h, w) + + # pad token embeddings along number of tokens dimension + token_embeddings = torch.cat([m1, m2], 
dim=1).permute(0,1,3,4,2).contiguous() + token_embeddings = token_embeddings.view(bz, -1, self.n_embd) # (B, an * T, C) + + # add (learnable) positional embedding for all tokens + x = self.drop(self.pos_emb + token_embeddings) # (B, an * T, C) + x = self.blocks(x) # (B, an * T, C) + x = self.ln_f(x) # (B, an * T, C) + x = x.view(bz, 2 * self.seq_len, self.vert_anchors, self.horz_anchors, self.n_embd) + x = x.permute(0,1,4,2,3).contiguous() # same as token_embeddings + + m1_out = x[:, :self.seq_len, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + m2_out = x[:, self.seq_len:, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + print("modality1 output:", m1_out.max(), m1_out.min()) + print("modality2 output:", m2_out.max(), m2_out.min()) + + return m1_out, m2_out + + +if __name__ == "__main__": + + feature1 = torch.randn((4, 512, 20, 20)) + feature2 = torch.randn((4, 512, 20, 20)) + model = TransFuse_layer(n_embd=512, n_head=4, block_exp=4, n_layer=8, num_anchors=20, seq_len=1) + print("TransFuse_layer:", model) + feat1, feat2 = model(feature1, feature2) + print(feat1.shape, feat2.shape) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Unet_ART.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Unet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b0f834270a921ae4eb428c92b3e782fbed19b3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/Unet_ART.py @@ -0,0 +1,243 @@ +""" +2023/07/06, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
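            Note: this variant additionally passes the features at every encoder scale
            through a stack of ART Transformer blocks (see forward()); self.img_size is
            hard-coded to 240, so inputs are assumed to be 240x240.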
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + output = conv_layer(output) + + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/__init__.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..334f1821a9b59e692641f70cdc61e7655b1a9c89 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/__init__.py @@ -0,0 +1,114 @@ +from .unet import build_model as UNET +from .mmunet import build_model as MUNET +from .humus_net import build_model as HUMUS +from .MINet import build_model as MINET +from .SANet import build_model as SANET +from .mtrans_net import build_model as MTRANS +from .unimodal_transformer import build_model as TRANS +from .cnn_transformer_new import build_model as CNN_TRANS +from .cnn import build_model as CNN +from .unet_transformer import build_model as UNET_TRANSFORMER +from .swin_transformer import build_model as SWIN_TRANS +from .restormer import build_model as RESTORMER + + +from .cnn_late_fusion import build_model as MULTI_CNN +from .mmunet_late_fusion import build_model as MMUNET_LATE +from .mmunet_early_fusion import build_model as MMUNET_EARLY +from .trans_unet.trans_unet import build_model as TRANSUNET +from .SwinFusion import build_model as SWINMULTI +from .mUnet_transformer import build_model as MUNET_TRANSFORMER +from .munet_multi_transfuse import build_model as MUNET_TRANSFUSE +from .munet_swinfusion import build_model as MUNET_SWINFUSE +from .munet_multi_concat import build_model as MUNET_CONCAT +from .munet_concat_decomp import build_model as MUNET_CONCAT_DECOMP +from .munet_multi_sum import build_model as MUNET_SUM +# from .mUnet_ARTfusion_new import build_model as MUNET_ART_FUSION +from .mUnet_ARTfusion import build_model as MUNET_ART_FUSION +from .mUnet_Restormer_fusion import build_model as MUNET_RESTORMER_FUSION +from .mUnet_ARTfusion_SeqConcat import build_model as MUNET_ART_FUSION_SEQ + +from .ARTUnet import build_model as ART +from .Unet_ART import build_model as UNET_ART +from .unet_restormer import build_model as UNET_RESTORMER +from .mmunet_ART import build_model as MMUNET_ART +from .mmunet_restormer import build_model as MMUNET_RESTORMER +from .mmunet_restormer_v2 import build_model as MMUNET_RESTORMER_V2 +from .mmunet_restormer_3blocks import build_model as MMUNET_RESTORMER_SMALL + +from .mARTUnet import build_model as ART_MULTI_INPUT + +from .ART_Restormer import build_model as ART_RESTORMER +from .ART_Restormer_v2 import build_model as ART_RESTORMER_V2 +from .swinIR import build_model as SWINIR + +from .Kspace_ConvNet import build_model as KSPACE_CONVNET +from .Kspace_mUnet_new import build_model as KSPACE_MUNET +from .kspace_mUnet_concat import build_model as KSPACE_MUNET_CONCAT +from .kspace_mUnet_AttnFusion import build_model as KSPACE_MUNET_ATTNFUSE + +from .DuDo_mUnet import build_model as DUDO_MUNET +from .DuDo_mUnet_ARTfusion import build_model as DUDO_MUNET_ARTFUSION 
+from .DuDo_mUnet_CatFusion import build_model as DUDO_MUNET_CONCAT + + +model_factory = { + 'unet_single': UNET, + 'humus_single': HUMUS, + 'transformer_single': TRANS, + 'cnn_transformer': CNN_TRANS, + 'cnn_single': CNN, + 'swin_trans_single': SWIN_TRANS, + 'trans_unet': TRANSUNET, + 'unet_transformer': UNET_TRANSFORMER, + 'restormer': RESTORMER, + 'unet_art': UNET_ART, + 'unet_restormer': UNET_RESTORMER, + 'art': ART, + 'art_restormer': ART_RESTORMER, + 'art_restormer_v2': ART_RESTORMER_V2, + 'swinIR': SWINIR, + + + 'munet_transformer': MUNET_TRANSFORMER, + 'munet_transfuse': MUNET_TRANSFUSE, + 'cnn_late_multi': MULTI_CNN, + 'unet_multi': MUNET, + 'unet_late_multi':MMUNET_LATE, + 'unet_early_multi':MMUNET_EARLY, + 'munet_ARTfusion': MUNET_ART_FUSION, + 'munet_restormer_fusion': MUNET_RESTORMER_FUSION, + + + 'minet_multi': MINET, + 'sanet_multi': SANET, + 'mtrans_multi': MTRANS, + 'swin_fusion': SWINMULTI, + 'munet_swinfuse': MUNET_SWINFUSE, + 'munet_concat': MUNET_CONCAT, + 'munet_concat_decomp': MUNET_CONCAT_DECOMP, + 'munet_sum': MUNET_SUM, + 'mmunet_art': MMUNET_ART, + 'mmunet_restormer': MMUNET_RESTORMER, + 'mmunet_restormer_small': MMUNET_RESTORMER_SMALL, + 'mmunet_restormer_v2': MMUNET_RESTORMER_V2, + + 'art_multi_input': ART_MULTI_INPUT, + + 'munet_ARTfusion_SeqConcat': MUNET_ART_FUSION_SEQ, + + 'kspace_ConvNet': KSPACE_CONVNET, + 'kspace_munet': KSPACE_MUNET, + 'kspace_munet_concat': KSPACE_MUNET_CONCAT, + 'kspace_munet_AttnFusion': KSPACE_MUNET_ATTNFUSE, + + 'dudo_munet': DUDO_MUNET, + 'dudo_munet_ARTfusion': DUDO_MUNET_ARTFUSION, + 'dudo_munet_concat': DUDO_MUNET_CONCAT, +} + + +def build_model_from_name(args): + assert args.model_name in model_factory.keys(), 'unknown model name' + + return model_factory[args.model_name](args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c130178e7cdabaaef8686da0cec341265f3c43d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn.py @@ -0,0 +1,246 @@ +""" +把our method中的单模态CNN_Transformer中的transformer换成几个conv_block用来下采样提取特征, 看是否能够达到和Unet相似的效果。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Reconstructor(nn.Module): + def __init__(self): + super(CNN_Reconstructor, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + nlatent = self.encoder.output_dim + self.decoder = Decoder(n_upsample=4, n_res=1, dim=nlatent, output_dim=1, pad_type='zero') + + + def forward(self, m): + #4,1,240,240 + feature_output = self.encoder.model(m) # [4, 256, 60, 60] + # print("output of the encoder:", feature_output.shape) + + output = self.decoder(feature_output) + + return output + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + 
nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + + def forward(self, input): + return self.model(input) + + + 
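# A minimal usage sketch for the single-modality CNN_Reconstructor defined above,
# assuming this file is imported as part of its package (the exact package root is
# not shown here, so the import path below is only illustrative). With n_downsample=4
# and dim=64, the encoder maps (B, 1, 240, 240) to (B, 1024, 15, 15) and the decoder
# upsamples back to (B, 1, 240, 240):
#
#     import torch
#     from networks.compare_models.cnn import CNN_Reconstructor  # illustrative path
#
#     model = CNN_Reconstructor()
#     x = torch.randn(1, 1, 240, 240)        # under-sampled input slice
#     y = model(x)
#     assert y.shape == (1, 1, 240, 240)     # Tanh-bounded reconstruction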
+################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + + +def build_model(args): + return CNN_Reconstructor() diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn_late_fusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..58dacad6910b6e896d00a0d36c9ea60e8d77217f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn_late_fusion.py @@ -0,0 +1,196 @@ +""" +在lequan的代码上,去掉transformer-based feature fusion部分。 +直接用CNN提取特征之后concat作为multi-modal representation. 用普通的decoder进行图像重建。 +把T1作为guidance modality, T2作为target modality. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
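    Late-fusion variant: the target and guidance modalities are encoded by two
    separate down-sampling branches, their bottleneck features are concatenated,
    and a plain transpose-convolution decoder (without skip connections)
    reconstructs the target image.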
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers1 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + self.down_sample_layers2 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers1.append(ConvBlock(ch, ch * 2, drop_prob)) + self.down_sample_layers2.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + + # self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # ch = chans + # for _ in range(num_pool_layers - 1): + # self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + # ch *= 2 + # self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv = nn.Sequential(nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), nn.Tanh()) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = torch.cat([image, aux_image], 1) + # for layer in self.down_sample_layers: + # output = layer(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + + # apply down-sampling layers + output = image + for layer in self.down_sample_layers1: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + output_m1 = output + + output = aux_image + for layer in self.down_sample_layers2: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + output_m2 = output + + # print("two modality outputs:", output_m1.shape, output.shape) + output = torch.cat([output_m1, output_m2], 1) + + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv in self.up_transpose_conv: + output = transpose_conv(output) + + output = self.up_conv(output) + # print("model output:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn_transformer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..261d4d214e35c3a9be5d28bc426ce7280124723c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn_transformer.py @@ -0,0 +1,289 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 
送到transformer网络中进行MRI重建。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=2, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=2, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 60 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 4 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
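        # Shape note: with n_downsample=2 the encoder output is (B, 256, 60, 60);
        # the Conv2d patch projector uses kernel_size = stride = 4, which yields
        # (60 // 4) ** 2 = 225 tokens of dimension patch_dim = 1024.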
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + feature_output = self.upsampling1(feature_output) # [4,512,30,30] + feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn_transformer_new.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn_transformer_new.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b2dea1e7cc8d456a684b6c839f052383999e84 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/cnn_transformer_new.py @@ -0,0 +1,290 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 送到transformer网络中进行MRI重建。 +2023/05/23: 之前是从尺寸为60*60的特征图上切patch, 现在可以尝试直接用conv_blocks提取15*15的特征图,然后按照patch_size = 1切成225个patches。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=4, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 1 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
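        # Shape note: the encoder shape comment above is carried over from
        # cnn_transformer.py; here n_downsample=4, so the encoder output is
        # (B, 1024, 15, 15) and, with patch_size=1, the projector still yields
        # 15 * 15 = 225 tokens of dimension patch_dim = 1024. The upsampling
        # blocks defined in __init__ are not used in this forward pass.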
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + # feature_output = self.upsampling1(feature_output) # [4,512,30,30] + # feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/humus_net.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/humus_net.py new file mode 100644 index 0000000000000000000000000000000000000000..d23701634f849b414866d71b3c23c6d07f3d6437 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/humus_net.py @@ -0,0 +1,1070 @@ +""" +HUMUS-Block +Hybrid Transformer-convolutional Multi-scale denoiser. + +We use parts of the code from +SwinIR (https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py) +Swin-Unet (https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.py) + +HUMUS-Net中是对kspace data操作, 多个unrolling cascade, 每个cascade中都利用HUMUS-block对图像进行denoising. +因为我们的方法输入是Under-sampled images, 所以不需要进行unrolling, 直接使用一个HUMUS-block进行MRI重建。 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # print("x shape:", x.shape) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. 
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + 
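        # Note: window_partition and window_reverse are exact inverses; partition
        # reshapes (B, H, W, C) into (B*num_windows, window_size, window_size, C) and
        # reverse restores (B, H, W, C), so values are only mixed within each window
        # by the attention call above.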
# reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + +class PatchExpandSkip(nn.Module): + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) + self.channel_mix = nn.Linear(dim, dim // 2, bias=False) + self.norm = norm_layer(dim // 2) + + def forward(self, x, skip): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x = torch.cat([x, skip], dim=-1) + x = self.channel_mix(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
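+        Example (illustrative only; shapes assume the ``WindowAttention`` defined
+        earlier in this file and an input grid divisible by ``window_size``):
+            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
+            ...                    num_heads=3, window_size=7)
+            >>> tokens = torch.randn(1, 56 * 56, 96)
+            >>> layer(tokens, (56, 56)).shape
+            torch.Size([1, 3136, 96])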
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, torch.tensor(x_size)) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ block_type: + D: downsampling block, + U: upsampling block + B: bottleneck block + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', block_type='B'): + super(RSTB, self).__init__() + self.dim = dim + conv_dim = dim // (patch_size ** 2) + divide_out_ch = 1 + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(conv_dim, conv_dim // divide_out_ch, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(conv_dim, conv_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // divide_out_ch, 3, 1, 1)) + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.block_type = block_type + if block_type == 'B': + self.reshape = torch.nn.Identity() + elif block_type == 'D': + self.reshape = PatchMerging(input_resolution, dim) + elif block_type == 'U': + self.reshape = PatchExpandSkip([res // 2 for res in input_resolution], dim * 2) + else: + raise ValueError('Unknown RSTB block type.') + + def forward(self, x, x_size, skip=None): + if self.block_type == 'U': + assert skip is not None, "Skip connection is required for patch expand" + x = self.reshape(x, skip) + + out = self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) + block_out = self.patch_embed(out) + x + + if self.block_type == 'D': + block_out = (self.reshape(block_out), block_out) # return skip connection + + return block_out + +class PatchEmbed(nn.Module): + r""" Fixed Image to Patch Embedding. Flattens image along spatial dimensions + without learned projection mapping. Only supports 1x1 patches. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + embed_dim (int): Number of output channels. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + +class PatchEmbedLearned(nn.Module): + r""" Image to Patch Embedding with arbitrary patch size and learned linear projection. 
+ Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_chans * patch_size[0] * patch_size[1], embed_dim) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.patch_to_vec(x) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + def patch_to_vec(self, x): + """ + Args: + x: (B, C, H, W) + Returns: + patches: (B, num_patches, patch_size * patch_size * C) + """ + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(B, H // self.patch_size[0], self.patch_size[0], W // self.patch_size[1], self.patch_size[1], C) + patches = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.patch_size[0], self.patch_size[1], C) + patches = patches.contiguous().view(B, self.num_patches, self.patch_size[0] * self.patch_size[1] * C) + return patches + + +class PatchUnEmbed(nn.Module): + r""" Fixed Image to Patch Unembedding via reshaping. Only supports patch size 1x1. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + +class PatchUnEmbedLearned(nn.Module): + r""" Patch Embedding to Image with learned linear projection. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. 
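+        Example (illustrative dims; with ``patch_size=2`` each token is expanded 4x by
+        ``proj_up`` and folded back into an ``in_chans``-channel image):
+            >>> unembed = PatchUnEmbedLearned(img_size=64, patch_size=2, in_chans=96)
+            >>> tokens = torch.randn(1, 32 * 32, 96)
+            >>> unembed(tokens).shape
+            torch.Size([1, 96, 64, 64])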
+ """ + + def __init__(self, img_size, patch_size, in_chans=None, out_chans=None, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.out_chans = out_chans + + self.proj_up = nn.Linear(in_chans, in_chans * patch_size[0] * patch_size[1]) + + def forward(self, x, x_size=None): + B, HW, C = x.shape + x = self.proj_up(x) + x = self.vec_to_patch(x) + return x + + def vec_to_patch(self, x): + B, HW, C = x.shape + x = x.view(B, self.patches_resolution[0], self.patches_resolution[1], C).contiguous() + x = x.view(B, self.patches_resolution[0], self.patch_size[0], self.patches_resolution[1], self.patch_size[1], C // (self.patch_size[0] * self.patch_size[1])).contiguous() + x = x.view(B, self.img_size[0], self.img_size[1], -1).contiguous().permute(0, 3, 1, 2) + return x + + +class HUMUSNet(nn.Module): + r""" HUMUS-Block + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depths (tuple(int)): Depth of each Swin Transformer layer in encoder and decoder paths. + num_heads (tuple(int)): Number of attention heads in different layers of encoder and decoder. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + use_checkpoint (bool): Whether to use checkpointing to save memory. + img_range: Image range. 1. or 255. + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + conv_downsample_first: use convolutional downsampling before MUST to reduce compute load on Transformers + """ + + def __init__(self, args, + img_size=[240, 240], + in_chans=1, + patch_size=1, + embed_dim=66, + depths=[2, 2, 2], + num_heads=[3, 6, 12], + window_size=5, + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + img_range=1., + resi_connection='1conv', + bottleneck_depth=2, + bottleneck_heads=24, + conv_downsample_first=True, + out_chans=1, + no_residual_learning=False, + **kwargs): + super(HUMUSNet, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans if out_chans is None else out_chans + + self.center_slice_out = (out_chans == 1) + + self.img_range = img_range + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + self.conv_downsample_first = conv_downsample_first + self.no_residual_learning = no_residual_learning + + ##################################################################################################### + ################################### 1, input block ################################### + input_conv_dim = embed_dim + self.conv_first = nn.Conv2d(num_in_ch, input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** self.num_layers) + self.mlp_ratio = mlp_ratio + + # Downsample for low-res feature extraction + if self.conv_downsample_first: + img_size = [im //2 for im in img_size] + self.conv_down_block = ConvBlock(input_conv_dim // 2, input_conv_dim, 0.0) + self.conv_down = DownsampConvBlock(input_conv_dim, input_conv_dim) + + # split image into non-overlapping patches + if patch_size > 1: + self.patch_embed = PatchEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + if patch_size > 1: + self.patch_unembed = PatchUnEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, out_chans=embed_dim, norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build MUST + # encoder + self.layers_down = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** i_layer) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + 
patches_resolution[1] // dim_scaler), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='D', + ) + self.layers_down.append(layer) + + # bottleneck + dim_scaler = (2 ** self.num_layers) + self.layer_bottleneck = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=bottleneck_depth, + num_heads=bottleneck_heads, + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='B', + ) + + # decoder + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** (self.num_layers - i_layer - 1)) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='U', + ) + self.layers_up.append(layer) + + self.norm_down = norm_layer(self.num_features) + self.norm_up = norm_layer(self.embed_dim) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(input_conv_dim, input_conv_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(input_conv_dim, input_conv_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim, 3, 1, 1)) + + # Upsample if needed + if self.conv_downsample_first: + self.conv_up_block = ConvBlock(input_conv_dim, input_conv_dim // 2, 0.0) + self.conv_up = TransposeConvBlock(input_conv_dim, input_conv_dim // 2) + + ##################################################################################################### + ################################ 3, output block ################################ + self.conv_last = nn.Conv2d(input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, num_out_ch, 3, 1, 1) + self.output_norm = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + 
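+    # Note on the stochastic-depth schedule built in __init__ above: with
+    # depths=[2, 2, 2] and drop_path_rate=0.1, dpr is a linear ramp over the six
+    # encoder blocks (roughly [0.00, 0.02, 0.04, 0.06, 0.08, 0.10]) and each stage i
+    # takes the slice dpr[sum(depths[:i]):sum(depths[:i + 1])]; the decoder reuses the
+    # mirrored slices, so matching encoder/decoder stages share the same rates.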
@torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + + # divisible by window size + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + + # divisible by total downsampling ratio + # this could be done more efficiently by combining the two + total_downsamp = int(2 ** (self.num_layers - 1)) + pad_h = (total_downsamp - h % total_downsamp) % total_downsamp + pad_w = (total_downsamp - w % total_downsamp) % total_downsamp + x = F.pad(x, (0, pad_w, 0, pad_h)) + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # encode + skip_cons = [] + for i, layer in enumerate(self.layers_down): + x, skip = layer(x, x_size=layer.input_resolution) + skip_cons.append(skip) + x = self.norm_down(x) # B L C + + # bottleneck + x = self.layer_bottleneck(x, self.layer_bottleneck.input_resolution) + + # decode + for i, layer in enumerate(self.layers_up): + x = layer(x, x_size=layer.input_resolution, skip=skip_cons[-i-1]) + x = self.norm_up(x) + x = self.patch_unembed(x, x_size) + return x + + def forward(self, x): + # print("input shape:", x.shape) + # print("input range:", x.max(), x.min()) + C, H, W = x.shape[1:] + center_slice = (C - 1) // 2 + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.conv_downsample_first: + x_first = self.conv_first(x) + x_down = self.conv_down(self.conv_down_block(x_first)) + res = self.conv_after_body(self.forward_features(x_down)) + res = self.conv_up(res) + res = torch.cat([res, x_first], dim=1) + res = self.conv_up_block(res) + + res = self.conv_last(res) + + if self.no_residual_learning: + x = res + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + res + else: + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + + if self.no_residual_learning: + x = self.conv_last(res) + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + self.conv_last(res) + + # print("output range before scaling:", x.max(), x.min()) + # print("self.mean:", self.mean) + + if self.center_slice_out: + x = x / self.img_range + self.mean[:, center_slice, ...].unsqueeze(1) + else: + x = x / self.img_range + self.mean + + x = self.output_norm(x) + + return x[:, :, :H, :W] + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
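+            Example (illustrative shapes; the 3x3 padded convolutions keep H and W):
+                >>> block = ConvBlock(in_chans=16, out_chans=32, drop_prob=0.0)
+                >>> block(torch.randn(2, 16, 64, 64)).shape
+                torch.Size([2, 32, 64, 64])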
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + +class DownsampConvBlock(nn.Module): + """ + A Downsampling Convolutional Block that consists of one strided convolution + layer followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.Conv2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H/2, W/2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return HUMUSNet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..02588c24559c0604ddbbd03a14b3b32e6cff2c8d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py @@ -0,0 +1,300 @@ +""" +Xiaohan Xing, 2023/11/22 +两个模态分别用CNN提取多个层级的特征, 每个层级都把特征沿着宽度拼接起来,得到(h, 2w, c)的特征。 +然后学习得到(h, 2w)的attention map, 和原始的特征相乘送到后面的层。 +""" + +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.attn_layers = nn.ModuleList() + + for l in range(self.num_pool_layers): + + self.attn_layers.append(nn.Sequential(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob), + nn.Conv2d(chans*(2**l), 2, kernel_size=1, stride=1), + nn.Sigmoid())) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, attn_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.attn_layers): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat之后求attention, 然后进行modulation. 
+ attn = attn_layer(torch.cat((output1, output2), 1)) + # print("spatial attention:", attn[:, 0, :, :].unsqueeze(1).shape) + # print("output1:", output1.shape) + output1 = output1 * (1 + attn[:, 0, :, :].unsqueeze(1)) + output2 = output2 * (1 + attn[:, 1, :, :].unsqueeze(1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/kspace_mUnet_concat.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/kspace_mUnet_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..d16bfdc83305d3e1f490e5ca4b445c6d730557ce --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/kspace_mUnet_concat.py @@ -0,0 +1,351 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Predict the low-frequency mask ############## +class Mask_Predictor(nn.Module): + def __init__(self, in_chans, out_chans, chans, drop_prob): + super(Mask_Predictor, self).__init__() + self.conv1 = ConvBlock(in_chans, chans, drop_prob) + self.conv2 = ConvBlock(chans, chans*2, drop_prob) + self.conv3 = ConvBlock(chans*2, chans*4, drop_prob) + + # self.conv4 = ConvBlock(chans*4, 1, drop_prob) + + self.FC = nn.Linear(chans*4, out_chans) + self.act = nn.Sigmoid() + + + def forward(self, x): + ### 三层卷积和down-sampling. + # print("[Mask predictor] input data range:", x.max(), x.min()) + output = F.avg_pool2d(self.conv1(x), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer1_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv2(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer2_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv3(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer3_feature range:", output.max(), output.min()) + + feature = F.adaptive_avg_pool2d(F.relu(output), (1, 1)).squeeze(-1).squeeze(-1) + # print("mask_prediction feature:", output[:, :5]) + # print("[Mask predictor] bottleneck feature range:", feature.max(), feature.min()) + + output = self.FC(feature) + # print("output before act:", output) + output = self.act(output) + # print("mask output:", output) + + return output + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, 
stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.mask_predictor1 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + self.mask_predictor2 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
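+            Note: this variant actually returns the 6-tuple ``(output1, output2,
+            t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight)``: the two
+            reconstructions, the predicted low-frequency mask coordinates for each
+            modality, and the corresponding data-consistency weights.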
+ """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + ### 预测做low-freq DC的mask区域. + ### 输出三个维度, 前两维是坐标,最后一维是mask内部的权重 + t1_mask = self.mask_predictor1(output1) + t2_mask = self.mask_predictor2(output2) + + # print("t1_mask and t2_mask:", t1_mask.shape, t2_mask.shape) + t1_mask_coords, t2_mask_coords = t1_mask[:, :2], t2_mask[:, :2] + t1_DC_weight, t2_DC_weight = t1_mask[:, -1], t2_mask[:, -1] + + + # ### 预测做low-freq DC的DC weight + # t1_DC_weight = self.mask_predictor1(output1) + # t2_DC_weight = self.mask_predictor2(output2) + + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight ## + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mARTUnet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9607d41ca14562a94e542a1804af5aeaf15bc123 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mARTUnet.py @@ -0,0 +1,563 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +2023/07/20, Xiaohan Xing +将两个模态的图像在输入端concat, 得到num_channels = 2的input, 然后送到ART模型进行T2 modality的重建。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # print("feature after layer_norm:", x.max(), x.min()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", 
attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=2, + out_channels=1, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img, aux_img): + stack = [] + bs, _, H, W = inp_img.shape + fuse_img = torch.cat((inp_img, aux_img), 1) + inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + 
self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=2, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..610b1e1e38d74f695a3a19f22a6a80f3152bf713 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import CS_TransformerBlock as TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class 
mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + # output1, output2 = image1, image2 + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
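+            # (English note for the comment above) The pre-fusion encoder features of each modality
+            # are pushed onto stack1 / stack2 and reach the decoders through skip connections.
+            # Per level l: ConvBlock features for each modality -> channel-wise concat ->
+            # ART TransformerBlocks at resolution (H // 2**l, W // 2**l) -> split back into the
+            # two modality streams -> 2x average pooling before the next encoder level.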
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py new file mode 100644 index 0000000000000000000000000000000000000000..55280925415f1db47cf746b59fb9710d47c9bff2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py @@ -0,0 +1,374 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any + +import matplotlib.pyplot as plt +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet_new import ConcatTransformerBlock, TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + 
return output + + + + +class mUnet_ARTfusion_SeqConcat(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4. + ): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.vis_feature_maps = False + self.vis_attention_maps = False + # self.vis_feature_maps = args.MODEL.VIS_FEAT_MAPS + # self.vis_attention_maps = args.MODEL.VIS_ATTN_MAPS + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + if l < 2: + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * (2 ** (l + 1))), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + else: + self.fuse_layers.append(nn.ModuleList([ConcatTransformerBlock(dim=int(chans * (2 ** l)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + feature_maps = { + 'modal1': {'before': [], 'after': []}, + 'modal2': {'before': [], 'after': []} + } + + attention_maps = [] + """ + attention_maps = [ + unet layer 1 attention maps (x2): [map from TransBlock1, TransBlock2, ...], + unet layer 2 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 3 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 4 attention maps (x6): [map from TransBlock1, TransBlock2, ...], + ] + """ + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
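+            # Note: the comment above appears to be carried over from the non-SeqConcat variant;
+            # in this file the features appended to stack1 / stack2 (further below) are the
+            # *post-fusion* ones. Levels 0-1 concatenate the two modality features along channels
+            # and fuse them with a plain TransformerBlock on the joint token sequence; levels 2-3
+            # keep separate token streams and fuse them with a ConcatTransformerBlock that returns
+            # (output1, output2). Attention / feature maps are collected only when the
+            # corresponding vis_* flag is set.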
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + if self.vis_feature_maps: + feature_maps['modal1']['before'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['before'].append(output2.detach().cpu().numpy()) + + ### 两个模态的特征concat + # output = torch.cat((output1, output2), 1) + + + + unet_layer_attn_maps = [] + """ + unet_layer_attn_maps: [ + attn_map from TransformerBlock1: (B, nHGroups, nWGroups, winSize, winSize), + attn_map from TransformerBlock2: (B, nHGroups, nWGroups, winSize, winSize), + ... + ] + """ + if l < 2: + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + else: + output1 = rearrange(output1, "b c h w -> b (h w) c").contiguous() + output2 = rearrange(output2, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + if l < 2: + if self.vis_attention_maps: + output, attn_map = layer(output, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output = layer(output, [H // (2**l), W // (2**l)]) + + else: + if self.vis_attention_maps: + output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + attention_maps.append(unet_layer_attn_maps) + + if l < 2: + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + else: + output1 = rearrange(output1, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output2 = rearrange(output2, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + # if self.vis_attention_maps: + # output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) # attn_map: (B, nHGroups, nWGroups, winSize, winSize) + # unet_layer_attn_maps.append(attn_map) + # else: + # output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + # attention_maps.append(unet_layer_attn_maps) + + + + if self.vis_feature_maps: + feature_maps['modal1']['after'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['after'].append(output2.detach().cpu().numpy()) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + if self.vis_feature_maps: + return output1, output2, feature_maps#, relation_stack1, relation_stack2 #, t1_features, t2_features + elif self.vis_attention_maps: + return output1, output2, attention_maps + else: + return output1, output2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion_SeqConcat(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8ab29b641fa0231ad0d163224a4a90b44c32a9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .restormer import TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8]): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.Sequential(*[TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[0], ffn_expansion_factor=2.66, \ + bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + ### 将encoder中multi-modal fusion之后的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = fuse_layer(output) + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_transformer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2a209758203b198502154b75496333ef91717a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mUnet_transformer.py @@ -0,0 +1,387 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+ +Xiaohan Xing, 2023/06/08 +分别用CNN提取两个模态的特征,然后送到transformer进行特征交互, 将经过transformer提取的特征和之前CNN提取的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
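+            # each decoder level pairs a TransposeConvBlock (2x upsample, channels halved) with a
+            # ConvBlock that takes the upsampled feature concatenated with the encoder skip feature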
self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Decoder_wo_skip(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder_wo_skip, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + output = transpose_conv(output) + output = conv(output) + + return output + + + + + +class mUnet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch*2 + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1, output2 = image1, image2 + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack1, output1 = self.encoder1(output1) ### output size: [4, 256, 15, 15] + stack2, output2 = self.encoder2(output2) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output1) # [4, 225, 256], 225 is the number of patches. + patch_embed_input1 = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + n_embed = self.to_patch_embedding(output2) # [4, 225, 256], 225 is the number of patches. 
+ patch_embed_input2 = F.normalize(n_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input1, patch_embed_input2), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1]/2)), int(np.sqrt(feature_output.shape[1]/2)) + patch_num = feature_output.shape[1] + feature_output1 = feature_output[:, :patch_num//2, :] + feature_output2 = feature_output[:, patch_num//2:, :] + feature_output1 = feature_output1.contiguous().view(b, feature_output1.shape[-1], h, w) # [4,c,15,15] + feature_output2 = feature_output2.contiguous().view(b, feature_output2.shape[-1], h, w) + + output = torch.cat((output1, output2), 1) + torch.cat((feature_output1, feature_output2), 1) + # print("CNN feature range:", output1.max(), output1.min()) + # print("Transformer feature range:", feature_output1.max(), feature_output1.min()) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_Transformer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee0c4784d04638df4ca259613777dde34980a2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/04 +Concatenate the input images from different modalities and input it into the UNet. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_ART.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..1a09f2900a8924cdc1a7c373270e6ff8aadcfa45 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_ART.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_ART_v2.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_ART_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe6bc0d9192a126f39932b40443badb5b60068 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_ART_v2.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_early_fusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_early_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d062f057a2080a471005125a00ee27453a4b0706 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_early_fusion.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +encapsulate the encoder and decoder for the Unet model. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + stack, output = self.encoder(output) + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + output = self.decoder(output, stack) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_late_fusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b48e7d4445afd6a5bc1eaa2b065c37431585e94 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_late_fusion.py @@ -0,0 +1,229 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +use Unet for the reconstruction of each modality, concat the intermediate features of both modalities. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("decoder output:", output.shape) + + return output + + + + +class mmUnet_late(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack1, output1 = self.encoder1(image) + stack2, output2 = self.encoder2(aux_image) + ### fuse two modalities to get multi-modal representation. + output = torch.cat([output1, output2], 1) + output = self.conv(output) ### from [4, 512, 15, 15] to [4, 512, 15, 15] + + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet_late(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_restormer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..88c7920d65c58ab82a00309a1ae501a8d5b33804 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_restormer.py @@ -0,0 +1,237 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e3afe68612727d3195872a121ec8040fb0f66ef2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py @@ -0,0 +1,235 @@ +""" +2023/07/18, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +之前 3 blocks的模型深层特征存在严重的gradient vanishing问题。现在把模型的encoder和decoder都改成只有3个blocks, 看是否还存在gradient vanishing问题。 +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 3, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2] + heads=[1, 2, 4] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_restormer_v2.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b51d3a218057206863230973d557173e891d3cd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mmunet_restormer_v2.py @@ -0,0 +1,220 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +这个版本中是将Unet encoder提取的特征直接连接到decoder, 经过restormer refine的特征只是传递到下一层的encoder block. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + output = feature + + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mtrans_mca.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mtrans_mca.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4c89df4af71e986fa7c24d522e6732371f3128 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mtrans_mca.py @@ -0,0 +1,119 @@ +import torch +from torch import nn + +from .mtrans_transformer import Mlp + +# implement by YunluYan + + +class LinearProject(nn.Module): + def __init__(self, dim_in, dim_out, fn): + super(LinearProject, self).__init__() + need_projection = dim_in != dim_out + self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity() + self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity() + self.fn = fn + + def forward(self, x, *args, **kwargs): + x = self.project_in(x) + x = self.fn(x, *args, **kwargs) + x = self.project_out(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): + super(MultiHeadCrossAttention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim*2, bias=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x, complement): + + # x [B, HW, C] + B_x, N_x, C_x = x.shape + + x_copy = x + + complement = torch.cat([x, complement], 1) + + B_c, N_c, C_c = complement.shape + + # q [B, heads, HW, C//num_heads] + q = self.to_q(x).reshape(B_x, N_x, self.num_heads, C_x//self.num_heads).permute(0, 2, 1, 3) + kv = self.to_kv(complement).reshape(B_c, N_c, 2, self.num_heads, C_c//self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_x, N_x, C_x) + + x = x + x_copy + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio = 1., attn_drop=0., proj_drop=0.,drop_path = 0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super(CrossTransformerEncoderLayer, self).__init__() + self.x_norm1 = norm_layer(dim) + self.c_norm1 = norm_layer(dim) + + self.attn = MultiHeadCrossAttention(dim, num_heads, attn_drop, proj_drop) + + self.x_norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) + + self.drop1 = nn.Dropout(drop_path) + self.drop2 = nn.Dropout(drop_path) + + def forward(self, x, complement): + x = self.x_norm1(x) + complement = self.c_norm1(complement) + + x = x + self.drop1(self.attn(x, complement)) + x = x + self.drop2(self.mlp(self.x_norm2(x))) + return x + + + +class CrossTransformer(nn.Module): + def __init__(self, x_dim, c_dim, depth, num_heads, mlp_ratio =1., attn_drop=0., proj_drop=0., drop_path =0.): + super(CrossTransformer, self).__init__() + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LinearProject(x_dim, c_dim, CrossTransformerEncoderLayer(c_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)), + 
LinearProject(c_dim, x_dim, CrossTransformerEncoderLayer(x_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)) + ])) + + def forward(self, x, complement): + for x_attn_complemnt, complement_attn_x in self.layers: + x = x_attn_complemnt(x, complement=complement) + x + complement = complement_attn_x(complement, complement=x) + complement + return x, complement + + + + + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mtrans_net.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mtrans_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e80dadb45725603c7ff1da664a45533575db25 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mtrans_net.py @@ -0,0 +1,145 @@ +""" +This script implement the paper "Multi-Modal Transformer for Accelerated MR Imaging (TMI 2022)" +""" + + +import torch + +from torch import nn +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler +from .mtrans_mca import CrossTransformer + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(INPUT_DIM, HEAD_HIDDEN_DIM): + return ReconstructionHead(INPUT_DIM, HEAD_HIDDEN_DIM) + + +class CrossCMMT(nn.Module): + + def __init__(self, args): + super(CrossCMMT, self).__init__() + + INPUT_SIZE = 240 + INPUT_DIM = 1 # the channel of input + OUTPUT_DIM = 1 # the channel of output + HEAD_HIDDEN_DIM = 16 # the hidden dim of Head + TRANSFORMER_DEPTH = 4 # the depth of the transformer + TRANSFORMER_NUM_HEADS = 4 # the head's num of multi head attention + TRANSFORMER_MLP_RATIO = 3 # the MLP RATIO Of transformer + TRANSFORMER_EMBED_DIM = 256 # the EMBED DIM of transformer + P1 = 8 + P2 = 16 + CTDEPTH = 4 + + + self.head = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + self.head2 = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + + x_patch_dim = HEAD_HIDDEN_DIM * P1 ** 2 + x_num_patches = (INPUT_SIZE // P1) ** 2 + + complement_patch_dim = HEAD_HIDDEN_DIM * P2 ** 2 + complement_num_patches = (INPUT_SIZE // P2) ** 2 + + self.x_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P1, + p2=P1), + ) + + self.complement_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P2, + p2=P2), + ) + + self.x_pos_embedding = nn.Parameter(torch.randn(1, x_num_patches, x_patch_dim)) + self.complement_pos_embedding = nn.Parameter(torch.randn(1, complement_num_patches, complement_patch_dim)) + + ### + self.cross_transformer = CrossTransformer(x_patch_dim, complement_patch_dim, CTDEPTH, + TRANSFORMER_NUM_HEADS, TRANSFORMER_MLP_RATIO) + + self.p1 = P1 + self.p2 
= P2 + + self.tail1 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + self.tail2 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x, complement): + x = self.head(x) + complement = self.head2(complement) + + x = x + complement + + b, _, h, w = x.shape + + _, _, c_h, c_w = complement.shape + + ### 得到每个模态的patch embedding (加上了position encoding) + x = self.x_patch_embbeding(x) + x += self.x_pos_embedding + + complement = self.complement_patch_embbeding(complement) + complement += self.complement_pos_embedding + + ### 将两个模态的patch embeddings送到cross transformer中提取特征。 + x, complement = self.cross_transformer(x, complement) + + c = int(x.shape[2] / (self.p1 * self.p1)) + H = int(h / self.p1) + W = int(w / self.p1) + + x = x.reshape(b, H, W, self.p1, self.p1, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w, ) + x = self.tail1(x) + + complement = complement.reshape(b, int(c_h/self.p2), int(c_w/self.p2), self.p2, self.p2, int(complement.shape[2]/self.p2/self.p2)) + complement = complement.permute(0, 5, 1, 3, 2, 4) + complement = complement.reshape(b, -1, c_h, c_w) + + complement = self.tail2(complement) + + return x, complement + + +def build_model(args): + return CrossCMMT(args) + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mtrans_transformer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mtrans_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..61314a6a81247b0740a1e98d078242b26ce73f90 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/mtrans_transformer.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def 
__init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
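+    Internally the values are obtained by sampling a uniform distribution and
+    mapping it through the inverse normal CDF (see ``_no_grad_trunc_normal_``
+    above), so no rejection sampling is involved.
+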
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(args, embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=args.MODEL.TRANSFORMER_DEPTH, + num_heads=args.MODEL.TRANSFORMER_NUM_HEADS, + mlp_ratio=args.MODEL.TRANSFORMER_MLP_RATIO) + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_concat_decomp.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_concat_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..07c6f64804dd586927814801b167163f0b65f58f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_concat_decomp.py @@ -0,0 +1,295 @@ +""" +Xiaohan Xing, 2023/06/25 +两个模态分别用CNN提取多个层级的特征, 每个层级的特征都经过concat之后得到fused feature, +然后每个模态都通过一层conv变换得到2C*h*w的特征, 其中C个channels作为common feature, 另C个channels作为specific features. 
+利用CDDFuse论文中的decomposition loss约束特征解耦。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
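+            args: experiment configuration object; it is accepted so that
+                ``build_model(args)`` can construct the network, but (as far as
+                this file shows) it is not used inside ``__init__`` itself.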
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.transform_layers1 = nn.ModuleList() + self.transform_layers2 = nn.ModuleList() + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.transform_layers1.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.transform_layers2.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + stack1_common, stack1_specific, stack2_common, stack2_specific = [], [], [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_transform_layer, net2_transform_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.transform_layers1, self.transform_layers2, \ + self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + # print("original feature:", output1.shape) + + ### 分别提取两个模态的common feature和specific feature + output1 = net1_transform_layer(output1) + # print("transformed feature:", output1.shape) + output1_common = output1[:, :(output1.shape[1]//2), :, :] + output1_specific = output1[:, (output1.shape[1]//2):, :, :] + output2 = net2_transform_layer(output2) + output2_common = output2[:, :(output2.shape[1]//2), :, :] + output2_specific = output2[:, (output2.shape[1]//2):, :, :] + + stack1_common.append(output1_common) + stack1_specific.append(output1_specific) + stack2_common.append(output2_common) + stack2_specific.append(output2_specific) + + ### 将一个模态的两组feature和另一模态的common feature合并, 然后经过conv变换得到和初始特征相同维度的fused feature送到下一层。 + output1 = net1_fuse_layer(torch.cat((output1_common, output1_specific, output2_common), 1)) + output2 = net2_fuse_layer(torch.cat((output2_common, output2_specific, output1_common), 1)) + # print("fused feature:", output1.shape) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, stack1_common, stack1_specific, stack2_common, stack2_specific + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
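+
+        Illustrative usage (editor's sketch, not part of the original code): both
+        3x3 convolutions use padding=1, so only the channel count changes, e.g.
+
+            >>> import torch
+            >>> block = ConvBlock(in_chans=1, out_chans=32, drop_prob=0.0)
+            >>> block(torch.randn(1, 1, 240, 240)).shape
+            torch.Size([1, 32, 240, 240])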
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_multi_concat.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_multi_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90373e1134210809b98fdc64e3d51d76123855 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_multi_concat.py @@ -0,0 +1,306 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != 
downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + # output1, output2 = image1, image2 + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
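+            ### (English gloss of the comment above) The encoder features from before the
+            ### multi-modal fusion are connected to the decoder through skip connections.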
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # relation1 = get_relation_matrix(output1) + # relation2 = get_relation_matrix(output2) + # relation_stack1.append(relation1) + # relation_stack2.append(relation2) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + # output1 = net1_layer(output1) + # output2 = net2_layer(output2) + + # ### 两个模态的特征concat + # fused_output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + # fused_output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # stack1.append(fused_output1) + # stack2.append(fused_output2) + + # ### downsampling features + # output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + # output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_multi_sum.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_multi_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4efb13b54ccbeaeb34fc51d3e72b0c50ee242 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_multi_sum.py @@ -0,0 +1,270 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != 
downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers): + + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + + ## 两个模态的特征求平均. + fused_output1 = (output1 + output2)/2.0 + fused_output2 = (output1 + output2)/2.0 + + output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features #, relation_stack1, relation_stack2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
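+
+    (English gloss) Resize the features of each level to a 5x5 spatial grid, then
+    compute the 25x25 relation matrix (cosine similarity) between the positions.
+
+    Shape sketch (editor's illustration, not part of the original code; H and W
+    must be multiples of 15 for the reshape below to work):
+
+        >>> import torch
+        >>> get_relation_matrix(torch.randn(2, 8, 30, 30)).shape
+        torch.Size([2, 25, 25])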
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_multi_transfuse.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_multi_transfuse.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ce250a078ba0847c729c1ef1b76452f64d0f07 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_multi_transfuse.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +2023/06/23 +每个层级的特征用TransFuse layer融合之后, 将两个模态的original features和transfuse features concat, +然后在每个模态中经过conv层变换得到下一层的输入。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .TransFuse import TransFuse_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_TransFuse(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + ### conv layer to transform the features after Transfuse layer. + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + self.n_anchors = 15 + self.avgpool = nn.AdaptiveAvgPool2d((self.n_anchors, self.n_anchors)) + self.transfuse_layer1 = TransFuse_layer(n_embd=self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer2 = TransFuse_layer(n_embd=2*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer3 = TransFuse_layer(n_embd=4*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer4 = TransFuse_layer(n_embd=8*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + + self.transfuse_layers = nn.Sequential(self.transfuse_layer1, self.transfuse_layer2, self.transfuse_layer3, self.transfuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, transfuse_layer, net1_fuse_layer, net2_fuse_layer in zip(self.encoder1.down_sample_layers, \ + self.encoder2.down_sample_layers, self.transfuse_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### extract multi-level features from the encoder of each modality. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + ### Transformer-based multi-modal fusion layer + feature1 = self.avgpool(output1) + feature2 = self.avgpool(output2) + feature1, feature2 = transfuse_layer(feature1, feature2) + feature1 = F.interpolate(feature1, scale_factor=output1.shape[-1]//self.n_anchors, mode='bilinear') + feature2 = F.interpolate(feature2, scale_factor=output2.shape[-1]//self.n_anchors, mode='bilinear') + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_TransFuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_swinfusion.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_swinfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6203c4331744f1da70f18e403360196a419b4cb5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/munet_swinfusion.py @@ -0,0 +1,268 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .SwinFuse_layer import SwinFusion_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_SwinFuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过SwinFusion的方式融合。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. 
+ num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.fuse_type = 'swinfuse' + self.img_size = 240 + # self.fuse_type = 'sum' + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # Swin transformer fusion modules + + self.fuse_layer1 = SwinFusion_layer(img_size=(self.img_size//2, self.img_size//2), patch_size=4, window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer2 = SwinFusion_layer(img_size=(self.img_size//4, self.img_size//4), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=2*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer3 = SwinFusion_layer(img_size=(self.img_size//8, self.img_size//8), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=4*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer4 = SwinFusion_layer(img_size=(self.img_size//16, self.img_size//16), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=8*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.swinfusion_layers = nn.Sequential(self.fuse_layer1, self.fuse_layer2, self.fuse_layer3, self.fuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, swinfuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.swinfusion_layers): + output1 = net1_layer(output1) + stack1.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # print("output of encoders:", output1.shape, output2.shape) + # print("CNN feature range:", output1.max(), output2.min()) + + ### Swin transformer-based multi-modal fusion. 
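+            ### Note: in the two averaging lines below, output1 is updated first, so the second
+            ### line computes output2 from the already-averaged output1 rather than from the
+            ### original encoder feature of modality 1.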
+            feature1, feature2 = swinfuse_layer(output1, output2)
+            # Both modalities receive the same symmetric average of the
+            # pre-fusion and post-fusion features.
+            fused = (output1 + feature1 + output2 + feature2) / 4.0
+            output1 = fused
+            output2 = fused
+
+        # output = torch.cat((output1, output2), 1)
+        output1 = self.conv1(output1)
+        output2 = self.conv2(output2)
+        output1 = self.decoder1(output1, stack1)
+        output2 = self.decoder2(output2, stack2)
+
+        return output1, output2
+
+
+class ConvBlock(nn.Module):
+    """
+    A Convolutional Block that consists of two convolution layers each followed by
+    instance normalization, LeakyReLU activation and dropout.
+    """
+
+    def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
+        """
+        Args:
+            in_chans: Number of channels in the input.
+            out_chans: Number of channels in the output.
+            drop_prob: Dropout probability.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.drop_prob = drop_prob
+
+        self.layers = nn.Sequential(
+            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Dropout2d(drop_prob),
+            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Dropout2d(drop_prob),
+        )
+
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns:
+            Output tensor of shape `(N, out_chans, H, W)`.
+        """
+        return self.layers(image)
+
+
+class TransposeConvBlock(nn.Module):
+    """
+    A Transpose Convolutional Block that consists of one convolution transpose
+    layer followed by instance normalization and LeakyReLU activation.
+    """
+
+    def __init__(self, in_chans: int, out_chans: int):
+        """
+        Args:
+            in_chans: Number of channels in the input.
+            out_chans: Number of channels in the output.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+
+        self.layers = nn.Sequential(
+            nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+        )
+
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns:
+            Output tensor of shape `(N, out_chans, H*2, W*2)`.
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_SwinFuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/original_MINet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/original_MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a115872320f677d8b34264eed95c22fff0df9c4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/original_MINet.py @@ -0,0 +1,335 @@ +from fastmri.models import common +import torch +import torch.nn as nn +import torch.nn.functional as F + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = 
self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,n_resgroups,n_resblocks,n_feats,conv=common.default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 2 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + common.Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class MINet(nn.Module): + def __init__(self, n_resgroups,n_resblocks, n_feats): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1,x2#x1=pd x2=pdfs + \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/restormer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3699cc809e438be4e1bd5efaba7d96e18d7f1602 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/restormer.py @@ -0,0 
+1,307 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + # print("feature before FFN projection:", x.max(), x.min()) + x = self.project_out(x) + # print("FFN output:", x.max(), x.min()) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + 
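+    # Multi-DConv Head Transposed Self-Attention: attention is computed across
+    # channels rather than spatial positions, so the per-head attention matrix
+    # has shape (C/heads) x (C/heads) instead of (H*W) x (H*W), keeping the cost
+    # linear in the number of pixels. q and k are L2-normalized along the spatial
+    # axis and scaled by the learnable per-head temperature defined above.
+    # Worked shape example (illustrative values, not from the training config):
+    # for x of shape (2, 32, 64, 64) with num_heads=1, q/k/v are rearranged to
+    # (2, 1, 32, 4096) and attn = q @ k.transpose(-2, -1) has shape (2, 1, 32, 32).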
+ + def forward(self, x): + b,c,h,w = x.shape + # print("Attention block input:", x.max(), x.min()) + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + # print("attention values:", attn) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("Attention block output:", out.max(), out.min()) + + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + # dim = 32, + # num_blocks = [4,4,4,4], + dim = 32, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + + # stack = [] + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + # print("encoder block1 feature:", out_enc_level1.shape, out_enc_level1.max(), out_enc_level1.min()) + # stack.append(out_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + # print("encoder block2 feature:", out_enc_level2.shape, out_enc_level2.max(), out_enc_level2.min()) + # stack.append(out_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + # print("encoder block3 feature:", out_enc_level3.shape, out_enc_level3.max(), out_enc_level3.min()) + # 
stack.append(out_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + # print("encoder block4 feature:", latent.shape, latent.max(), latent.min()) + # stack.append(latent) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + + return out_dec_level1 #, stack + + + +def build_model(args): + return Restormer() diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/restormer_block.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/restormer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..170ae6a826e5a3e356cb48f8c89500c2d5242816 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/restormer_block.py @@ -0,0 +1,321 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + 
self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + # print("input to the LayerNorm layer:", x.max(), x.min()) + h, w = x.shape[-2:] + x = to_4d(self.body(to_3d(x)), h, w) + # print("output of the LayerNorm:", x.max(), x.min()) + return x + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature = 1.0 / ((dim//num_heads) ** 0.5) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + # print("input to the Attention layer:", x.max(), x.min()) + + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + # print("q range:", q.max(), q.min()) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + # print("scaling parameter:", self.temperature) + # print("attention matrix before softmax:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + attn = attn.softmax(dim=-1) + # print("attention matrix range:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + # print("v range:", v.max(), v.min(), v.mean()) + + out = (attn @ v) + # print("self-attention output:", out.max(), out.min()) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("output of attention layer:", out.max(), out.min()) + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, pre_norm): + super(TransformerBlock, self).__init__() + + self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + self.pre_norm = pre_norm + 
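+    # Differences from the TransformerBlock in restormer.py: a 1x1 conv is applied
+    # at the block entry, the attention temperature is the fixed scalar
+    # 1 / sqrt(dim // num_heads) rather than a learned per-head parameter, and the
+    # `pre_norm` flag selects pre-norm residuals (x + attn(norm(x))) versus
+    # post-norm residuals (norm(x + attn(x))). `pre_norm` has no default value,
+    # so every TransformerBlock(...) call must pass it explicitly.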
+ def forward(self, x): + x = self.conv(x) + # print("restormer block input:", x.max(), x.min()) + if self.pre_norm: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + else: + x = self.norm1(x + self.attn(x)) + x = self.norm2(x + self.ffn(x)) + # x = self.norm1(self.attn(x)) + # x = self.norm2(self.ffn(x)) + # x = x + self.attn(x) + # x = x + self.ffn(x) + # print("restormer block output:", x.max(), x.min()) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim = 48, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + 
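+        # Channel bookkeeping along the decoder (with the default dim = 48):
+        # each Upsample halves the channel count via PixelShuffle(2), so the
+        # upsampled level-3 features arrive with 2*dim channels; concatenating
+        # the level-2 encoder skip (2*dim) gives 4*dim, which the 1x1
+        # reduce_chan_level2 conv brings back to 2*dim. Level 1 has no such
+        # reduction, so decoder_level1 and refinement operate at 2*dim channels.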
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + + + return out_dec_level1 + + + + +if __name__ == '__main__': + model = Restormer() + # print(model) + + A = torch.randn((1, 1, 240, 240)) + x = model(A) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/swinIR.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/swinIR.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfc9c175c02531ac80a05ba4abfc0776f752dd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/swinIR.py @@ -0,0 +1,874 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * 
(self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 4, 4, 6], + embed_dim=32, num_heads=[6, 6, 6, 6], mlp_ratio=4, upsampler='pixelshuffledirect') + + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/swin_transformer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..877bdfbffc91012e4ac31f41b838ebb116f4a825 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/swin_transformer.py @@ -0,0 +1,878 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
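+# Note: this copy keeps defaults adapted for single-channel MRI reconstruction
+# (in_chans=1, upscale=1); build_model() at the bottom of this file instantiates it with
+# window_size=8, embed_dim=60, depths=[2, 2, 2] and the 'pixelshuffledirect' upsampler.
+# Illustrative shape check (a sketch, not part of the upstream SwinIR code):
+#     model = build_model(None)               # the args argument is unused by build_model
+#     y = model(torch.randn(1, 1, 240, 240))  # 240 is divisible by window_size=8, no padding needed
+#     assert y.shape == (1, 1, 240, 240)      # upscale=1, so output matches input size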
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
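+
+    Forward pass (schematic, matching the ``forward`` method below): the token sequence is
+    processed by the residual group of Swin Transformer blocks, mapped back to a feature
+    map, refined by the convolution, re-flattened to tokens, and added to the input:
+        out = patch_embed(conv(patch_unembed(residual_group(x, x_size), x_size))) + x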
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=1, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + # print("shallow feature:", x.shape) + x = self.conv_after_body(self.forward_features(x)) + x + # print("deep feature:", x.shape) + x = self.upsample(x) + # print("after upsample:", x.shape) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x 
= self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + # return SwinIR(upscale=1, img_size=(240, 240), + # window_size=8, img_range=1., depths=[6, 6, 6, 6], + # embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + +if __name__ == '__main__': + height = 240 + width = 240 + window_size = 8 + model = SwinIR(upscale=1, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + x = torch.randn((1, 1, height, width)) + print("input:", x.shape) + x = model(x) + print("output:", x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet.zip b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet.zip new file mode 100644 index 0000000000000000000000000000000000000000..e7b71c1287551fd6119efd7c62da2196307998f5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:369de2a28036532c446409ee1f7b953d5763641ffd452675613e1bb9bba89eae +size 19354 diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet/trans_unet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet/trans_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa61dbadfcaee9b26594307adc90cdd1d2a84e3e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet/trans_unet.py @@ -0,0 +1,489 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math + +from os.path import join as pjoin + +import torch +import torch.nn as nn +import numpy as np + +from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair +from scipy import ndimage +from . 
import vit_seg_configs as configs +# import .vit_seg_configs as configs +from .vit_seg_modeling_resnet_skip import ResNetV2 + + +logger = logging.getLogger(__name__) + + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class 
Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + # print("image size:", img_size, "grid size:", grid_size) + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + # print("patch_size_real", patch_size_real) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + # print("input x:", x.shape) ### (4, 3, 240, 240) + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + # print("output of the hybrid model:", x.shape) ### (4, 1024, 15, 15) + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = 
np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + # print(self.encoder) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) ### "features" store the features from different layers. + # print("embedding_output:", embedding_output.shape) ## (B, n_patch, hidden) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + # print("encoded feature:", encoded.shape) + return encoded, attn_weights, features + + # return embedding_output, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class ReconstructionHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + 
self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + # print(self.blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): + super(VisionTransformer, self).__init__() + self.num_classes = num_classes + self.zero_head = zero_head + self.classifier = config.classifier + self.transformer = Transformer(config, img_size, vis) + self.decoder = DecoderCup(config) + self.segmentation_head = ReconstructionHead( + in_channels=config['decoder_channels'][-1], + out_channels=config['n_classes'], + kernel_size=3, + ) + self.activation = nn.Tanh() + self.config = config + + def forward(self, x): + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + ### hybrid CNN and transformer to extract features from the input images. + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + + # ### extract features with CNN only. 
+ # x, features = self.transformer(x) # (B, n_patch, hidden) + + x = self.decoder(x, features) + # logits = self.segmentation_head(x) + logits = self.activation(self.segmentation_head(x)) + # print("logits:", logits.shape) + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + + + +def build_model(args): + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + print(net) + # net.load_from(weights=np.load(config_vit.pretrained_path)) + return net + + +if __name__ == "__main__": + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + net.load_from(weights=np.load(config_vit.pretrained_path)) + + input_data = torch.rand(4, 1, 240, 240).cuda() + print("input:", input_data.shape) + output = net(input_data) + print("output:", output.shape) diff --git 
a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bed7d3346e30328a38ac6fef1621f788d905df1e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py @@ -0,0 +1,132 @@ +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 4 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + # config.patches.grid = (16, 16) + config.patches.grid = (15, 15) ### for the input image with size (240, 240) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'recon' + config.pretrained_path = '/home/xiaohan/workspace/MSL_MRI/code/pretrained_model/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 1 + config.n_skip = 3 + config.activation = 'tanh' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. 
customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae80ef7d96c46cb91bda849901aef34008e62ed --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py @@ -0,0 +1,160 @@ +import math + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. 
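# (Added note, not part of the original diff.) Concretely, the shortcut gets its
# own 1x1 StdConv2d (conv1x1 above is weight-standardized) followed by a
# GroupNorm whose num_groups equals the channel count, i.e. per-channel
# (InstanceNorm-style) statistics, independent of the gn1/gn2/gn3 norms on the
# main branch.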
+ self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - 
x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/transformer_modules.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/transformer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..900042884305619e39ad9b792b3b1936dfbd37b3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/transformer_modules.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * 
self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
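# (Added note, not part of the original diff.) At this point l and u are the
# standard-normal CDF values of the truncation bounds (a - mean)/std and
# (b - mean)/std. Sampling uniformly in [2l-1, 2u-1] and applying erfinv_ below
# is inverse-transform sampling from the truncated standard normal, since
# Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1); the result is then rescaled by std,
# shifted by mean, and clamped to [a, b].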
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=4, + num_heads=4, + mlp_ratio=3) + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dd935cd34a121d8079a8f5184d4ea29e15c8da --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unet.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F + + +class Unet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. 
In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = image + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + # print("unet feature range:", output.max(), output.min()) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
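# (Illustrative sketch, not part of the original diff.) The reflect-padding step
# in Unet.forward above: with an odd input size, the upsampled decoder feature
# can be one pixel smaller than the stored skip tensor, so it is padded on the
# right/bottom before the channel-wise concatenation.
import torch
import torch.nn.functional as F
skip = torch.randn(1, 32, 15, 15)
up = torch.randn(1, 32, 14, 14)                           # one pixel short after upsampling
padding = [0, 0, 0, 0]
if up.shape[-1] != skip.shape[-1]:
    padding[1] = 1                                        # pad right
if up.shape[-2] != skip.shape[-2]:
    padding[3] = 1                                        # pad bottom
up = F.pad(up, padding, "reflect")
print(torch.cat([up, skip], dim=1).shape)                 # torch.Size([1, 64, 15, 15])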
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unet_restormer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f78aa5f192957b2706c6f691dfc66c427b65ae --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unet_restormer.py @@ -0,0 +1,234 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + # print("Unet feature range:", output.max(), output.min()) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unet_transformer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..982fe23cec321de6b636e04c51b62037054ea432 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unet_transformer.py @@ -0,0 +1,346 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +对于Unet中的high-level feature, 用Transformer进行特征交互,然后和Transformer之前的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def 
get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Unet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
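# (Illustrative sketch, not part of the original diff.) The einops pattern used by
# the Attention module above: 'b n (h d) -> b h n d' splits the channel dimension
# into attention heads, and the inverse pattern merges the heads back after the
# attention-weighted sum.
import torch
from einops import rearrange
qkv_chunk = torch.randn(2, 225, 8 * 64)                   # (batch, tokens, heads * dim_head)
per_head = rearrange(qkv_chunk, 'b n (h d) -> b h n d', h=8)
merged = rearrange(per_head, 'b h n d -> b n (h d)')
print(per_head.shape, merged.shape)                       # (2, 8, 225, 64) (2, 225, 512)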
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output = image + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack, output = self.encoder(output) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output) # [4, 225, 256], 225 is the number of patches. + patch_embed_input = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1])), int(np.sqrt(feature_output.shape[1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[-1], h, w) # [4,512*2,24,24] + + output = torch.cat((output, feature_output), 1) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output = self.decoder(output, stack) + + return output + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. 
+ + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Transformer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unimodal_transformer.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unimodal_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c99214e10bee2f7a1be49f8560799ffbc14ed0a5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/compare_models/unimodal_transformer.py @@ -0,0 +1,112 @@ +import torch + +from torch import nn +from .transformer_modules import build_transformer +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(): + return ReconstructionHead(input_dim=1, hidden_dim=12) + + + +class CMMT(nn.Module): + + def __init__(self, args): + super(CMMT, self).__init__() + + self.head = build_head() + + HEAD_HIDDEN_DIM = 12 + PATCH_SIZE = 16 + INPUT_SIZE = 240 + OUTPUT_DIM = 1 + + patch_dim = HEAD_HIDDEN_DIM* PATCH_SIZE ** 2 + num_patches = (INPUT_SIZE // PATCH_SIZE) ** 2 + + self.transformer = build_transformer(patch_dim) + + self.patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=PATCH_SIZE, p2=PATCH_SIZE), + ) + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, patch_dim)) + + + self.p1 = PATCH_SIZE + self.p2 = PATCH_SIZE + + self.tail = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x): + + b,_ , h, w = x.shape + + x = self.head(x) + + x= self.patch_embbeding(x) + + x += self.pos_embedding + + x = self.transformer(x) # b HW p1p2c + + c = int(x.shape[2]/(self.p1*self.p2)) 
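# (Illustrative sketch, not part of the original diff.) The next few lines of
# CMMT.forward undo the patch embedding: reshape the transformer output to
# (b, H, W, p1, p2, c), permute to (b, c, H, p1, W, p2), and merge back to the
# full-resolution image -- the inverse of Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)').
import torch
from einops.layers.torch import Rearrange
b, c, h, w, p = 2, 12, 240, 240, 16
img = torch.randn(b, c, h, w)
tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)(img)
H, W = h // p, w // p
restored = tokens.reshape(b, H, W, p, p, c).permute(0, 5, 1, 3, 2, 4).reshape(b, c, h, w)
print(torch.allclose(img, restored))                      # True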
+ H = int(h/self.p1) + W = int(w/self.p2) + + x = x.reshape(b, H, W, self.p1, self.p2, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w,) + x = self.tail(x) + + return x + + +def build_model(args): + return CMMT(args) + + + + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/modules.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/modules.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
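# (Illustrative sketch, not part of the original diff.) What the InstanceNorm
# class defined below computes: per-sample, per-channel normalization over the
# spatial dimensions, matching nn.InstanceNorm2d up to the affine parameters.
import torch
x = torch.randn(2, 8, 16, 16)
flat = x.view(2, 8, -1)
mean = flat.mean(2, keepdim=True)
inv_std = torch.rsqrt(((flat - mean) ** 2).mean(2, keepdim=True) + 1e-5)
normed = ((flat - mean) * inv_std).view_as(x)
print(normed.mean(dim=(2, 3)).abs().max())                # ~0 for every sample/channel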
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/mynet.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/mynet.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e93c185c773070d07437777b9c01ff11824d4b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/networks/mynet.py @@ -0,0 +1,388 @@ +import torch +from torch import nn +from networks import common_freq as common + + +class TwoBranch(nn.Module): + def __init__(self, args): + super(TwoBranch, self).__init__() + + num_group = 4 + num_every_group = args.base_num_every_group + self.args = args + + self.init_T2_frq_branch(args) + self.init_T2_spa_branch(args, num_every_group) + self.init_T2_fre_spa_fusion(args) + + self.init_T1_frq_branch(args) + self.init_T1_spa_branch(args, num_every_group) + + self.init_modality_fre_fusion(args) + self.init_modality_spa_fusion(args) + + + def init_T2_frq_branch(self, args): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_fre = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + 
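# (Added note, not part of the original diff.) Each scale of the T2 frequency
# branch pairs a DownSample with FreBlock9 modules, and every *_fre_mo wrapper
# adds one more FreBlock9 after the downsampling stage; the FreBlock9 / DownSample
# implementations live in networks.common_freq and are not included in this diff.
# The spatial branch below mirrors this layout with ResidualGroup blocks, and the
# two branches (plus the T1 encoders) are later merged by the FuseBlock7 /
# Modality_FuseBlock6 fusion modules built in the init_*_fusion methods.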
modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_neck_fre = [common.FreBlock9(args.num_features, args) + ] + self.neck_fre = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up1_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up1_fre = nn.Sequential(*modules_up1_fre) + self.up1_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up2_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up2_fre = nn.Sequential(*modules_up2_fre) + self.up2_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up3_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up3_fre = nn.Sequential(*modules_up3_fre) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + # define tail module + modules_tail_fre = [ + common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act)] + self.tail_fre = nn.Sequential(*modules_tail_fre) + + def init_T2_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1 = nn.Sequential(*modules_down1) + + + self.down1_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2 = nn.Sequential(*modules_down2) + + self.down2_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down3 = nn.Sequential(*modules_down3) + self.down3_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.neck = nn.Sequential(*modules_neck) + + self.neck_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_up1 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up1 = nn.Sequential(*modules_up1) + + self.up1_mo = nn.Sequential(common.ResidualGroup( + 
args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_up2 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up2 = nn.Sequential(*modules_up2) + self.up2_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + + modules_up3 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up3 = nn.Sequential(*modules_up3) + self.up3_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + # define tail module + modules_tail = [ + common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act)] + + self.tail = nn.Sequential(*modules_tail) + + def init_T2_fre_spa_fusion(self, args): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(args.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, args): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_fre_T1 = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre_T1 = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre_T1 = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre_T1 = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_neck_fre = [common.FreBlock9(args.num_features, args) + ] + self.neck_fre_T1 = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + def init_T1_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_T1 = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = nn.Sequential(*modules_down1) + + + self.down1_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = nn.Sequential(*modules_down2) + + self.down2_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, 
n_resblocks=num_every_group, norm=None) + ] + self.down3_T1 = nn.Sequential(*modules_down3) + self.down3_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.neck_T1 = nn.Sequential(*modules_neck) + + self.neck_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + + def init_modality_fre_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + def forward(self, main, aux): + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo) # 64 + up2_fre_mo = self.up2_fre_mo(up2_fre) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse) + down1_fuse_mo = self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = 
self.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.up3(up2_fuse_mo) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + # import matplotlib.pyplot as plt + # plt.axis('off') + # plt.imshow((255*up3_fre_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fre_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_fuse_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fuse_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + # breakpoint() + + res = self.tail(up3_fuse_mo) + + return {'img_out': res + main, 'img_fre': res_fre + main} + +def make_model(args): + return TwoBranch(args) + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/option.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/option.py new file mode 100644 index 0000000000000000000000000000000000000000..f6822c0797cdf1191be2a2c6c16842d65d3b8138 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/option.py @@ -0,0 +1,62 @@ +import argparse + +parser = argparse.ArgumentParser(description='MRI recon') +parser.add_argument('--root_path', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/selected_images/') +parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') +parser.add_argument('--low_field_SNR', type=int, default=15, help='SNR of the simulated low-field image') +parser.add_argument('--phase', type=str, default='train', help='Name of phase') +parser.add_argument('--gpu', type=str, default='0', help='GPU to use') +parser.add_argument('--exp', type=str, default='msl_model', help='model_name') +parser.add_argument('--max_iterations', type=int, default=100000, help='maximum epoch number to train') +parser.add_argument('--batch_size', type=int, default=8, help='batch_size per gpu') +parser.add_argument('--base_lr', type=float, default=0.0002, help='maximum epoch numaber to train') +parser.add_argument('--seed', type=int, default=1337, help='random 
seed') +parser.add_argument('--resume', type=str, default=None, help='resume') +parser.add_argument('--relation_consistency', type=str, default='False', help='regularize the consistency of feature relation') +parser.add_argument('--clip_grad', type=str, default='True', help='clip gradient of the network parameters') + + +parser.add_argument('--norm', type=str, default='False', help='Norm Layer between UNet and Transformer') +parser.add_argument('--input_normalize', type=str, default='mean_std', help='choose from [min_max, mean_std, divide]') + +parser.add_argument("--dist_url", default="63654") + +parser.add_argument('--scale', type=int, default=8, + help='super resolution scale') +parser.add_argument('--base_num_every_group', type=int, default=2, + help='super resolution scale') + + +parser.add_argument('--rgb_range', type=int, default=255, + help='maximum value of RGB') +parser.add_argument('--n_colors', type=int, default=3, + help='number of color channels to use') +parser.add_argument('--augment', action='store_true', + help='use data augmentation') +parser.add_argument('--fftloss', action='store_true', + help='use data augmentation') +parser.add_argument('--fftd', action='store_true', + help='use data augmentation') +parser.add_argument('--fftd_weight', type=float, default=0.1, + help='use data augmentation') +parser.add_argument('--fft_weight', type=float, default=0.01) + +# Model specifications +parser.add_argument('--model', type=str, default='MYNET') +parser.add_argument('--act', type=str, default='PReLU') +parser.add_argument('--data_range', type=float, default=1) +parser.add_argument('--num_channels', type=int, default=1) +parser.add_argument('--num_features', type=int, default=64) + +parser.add_argument('--n_feats', type=int, default=64, + help='number of feature maps') +parser.add_argument('--res_scale', type=float, default=0.2, + help='residual scaling') + +parser.add_argument('--MASKTYPE', type=str, default='random') # "random" or "equispaced" +parser.add_argument('--CENTER_FRACTIONS', nargs='+', type=float) +parser.add_argument('--ACCELERATIONS', nargs='+', type=int) + + + +args = parser.parse_args() diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/test_brats.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/test_brats.py new file mode 100644 index 0000000000000000000000000000000000000000..f371d55781cb361124387c7c651d5b133a2f5600 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/test_brats.py @@ -0,0 +1,150 @@ +import os +import sys +from tqdm import tqdm +import shutil +import argparse +import logging +import numpy as np +from skimage import io +from scipy.ndimage import zoom + +import torch +import torch.nn as nn +from torchvision import transforms +from torch.utils.data import DataLoader +from networks.compare_models import build_model_from_name +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import ToTensor +from networks.mynet import TwoBranch +from utils import bright, trunc +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + + +parser = argparse.ArgumentParser() +parser.add_argument('--root_path', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/selected_images/') +parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') +parser.add_argument('--low_field_SNR', type=int, default=15, help='SNR of the simulated low-field image') +parser.add_argument('--phase', 
type=str, default='test', help='Name of phase') +parser.add_argument('--gpu', type=str, default='0', help='GPU to use') +parser.add_argument('--exp', type=str, default='msl_model', help='model_name') +parser.add_argument('--seed', type=int, default=1337, help='random seed') +parser.add_argument('--base_lr', type=float, default=0.0002, help='maximum epoch numaber to train') + +parser.add_argument('--model_name', type=str, default='unet_single', help='model_name') +parser.add_argument('--relation_consistency', type=str, default='False', help='regularize the consistency of feature relation') +parser.add_argument('--norm', type=str, default='False', help='Norm Layer between UNet and Transformer') +parser.add_argument('--input_normalize', type=str, default='mean_std', help='choose from [min_max, mean_std, divide]') + +# args = parser.parse_args() +from option import args +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + +def normalize_output(out_img): + out_img = (out_img - out_img.min())/(out_img.max() - out_img.min() + 1e-8) + return out_img + + +if __name__ == "__main__": + ## make logger file + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + network = TwoBranch(args).cuda() + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + db_test = MyDataset(split='test', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([ToTensor()]), + base_dir=test_data_path, input_normalize = args.input_normalize) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=2, pin_memory=True) + + if args.phase == 'test': + + save_mode_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + print('load weights from ' + save_mode_path) + checkpoint = torch.load(save_mode_path) + network.load_state_dict(checkpoint['network']) + network.eval() + cnt = 0 + save_path = snapshot_path + '/result_case/' + feature_save_path = snapshot_path + '/feature_visualization/' + if not os.path.exists(save_path): + os.makedirs(save_path) + if not os.path.exists(feature_save_path): + os.makedirs(feature_save_path) + + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + for (sampled_batch, sample_stats) in tqdm(testloader, ncols=70): + cnt += 1 + + print('processing ' + str(cnt) + ' image') + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + + t1_out, t2_out = None, None + + + t2_out = network(t2_in, t1_in)['img_out'] + t2_out_2 = network(t2_in, t1_in)['img_out'] + + t1_mean = sample_stats['t1_mean'].data.cpu().numpy()[0] + t1_std = sample_stats['t1_std'].data.cpu().numpy()[0] + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 
255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_2_img = (np.clip(t2_out_2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + + io.imsave(save_path + str(cnt) + '_t1.png', bright(t1_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2.png', bright(t2_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_in.png', bright(t2_in_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_out.png', bright(t2_out_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_out2.png', bright(t2_out_2_img,0,0.8)) + + + if t2_out is not None: + t2_out_img[t2_out_img < 0.0] = 0.0 + t2_img[t2_img < 0.0] = 0.0 + MSE = mean_squared_error(t2_img, t2_out_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_out_img) + SSIM = structural_similarity(t2_img, t2_out_img) + t2_MSE_all.append(MSE) + t2_PSNR_all.append(PSNR) + t2_SSIM_all.append(SSIM) + print("[t2 MRI] MSE:", MSE, "PSNR:", PSNR, "SSIM:", SSIM) + + # if cnt > 20: + # break + + print("[T2 MRI:] average MSE:", np.array(t2_MSE_all).mean(), "average PSNR:", np.array(t2_PSNR_all).mean(), "average SSIM:", np.array(t2_SSIM_all).mean()) + print("[T2 MRI:] average MSE:", np.array(t2_MSE_all).std(), "average PSNR:", np.array(t2_PSNR_all).std(), "average SSIM:", np.array(t2_SSIM_all).std()) + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/test_fastmri.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/test_fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..10153dd8b437436019f5217abeaf13da54fc4e37 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/test_fastmri.py @@ -0,0 +1,168 @@ +import os +import sys +from tqdm import tqdm +import shutil +import argparse +import logging +import numpy as np +from skimage import io +from scipy.ndimage import zoom + +import torch +import torch.nn as nn +from torchvision import transforms +from torch.utils.data import DataLoader +from networks.compare_models import build_model_from_name +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import ToTensor +from networks.mynet import TwoBranch +from utils import bright, trunc +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +from option import args + +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + +def normalize_output(out_img): + out_img = (out_img - out_img.min())/(out_img.max() - out_img.min() + 1e-8) + return out_img + +from metric import nmse, psnr, ssim, AverageMeter +from collections import defaultdict +@torch.no_grad() +def evaluate(model, data_loader, device, save_path): + os.makedirs(save_path, exist_ok=True) + model.eval() + nmse_meter = AverageMeter() + psnr_meter = AverageMeter() + ssim_meter = 
AverageMeter() + output_dic = defaultdict(dict) + target_dic = defaultdict(dict) + input_dic = defaultdict(dict) + flag=0 + last_name='no' + for data in data_loader: + pd, pdfs, _ = data + name = os.path.basename(pdfs[4][0]).split('.')[0] + if not last_name == name: + last_name = name + flag+=1 + if flag < 3: + continue + elif flag >= 4: + break + else: + pass + + target = pdfs[1] + + mean = pdfs[2] + std = pdfs[3] + + fname = pdfs[4] + slice_num = pdfs[5] + + mean = mean.unsqueeze(1).unsqueeze(2) + std = std.unsqueeze(1).unsqueeze(2) + + mean = mean.to(device) + std = std.to(device) + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + + pd_img = pd_img.to(device) + pdfs_img = pdfs_img.to(device) + target = target.to(device) + + outputs = network(pdfs_img, pd_img)['img_out'] + outputs = outputs.squeeze(1) + + outputs_save = outputs[0].cpu().numpy()/6.0 + outputs_save = np.clip(outputs_save, a_min=-1, a_max=1) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '.png', target[0].cpu().numpy()/6.0) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_in.png', pdfs_img[0][0].cpu().numpy()/6.0) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_out.png', outputs_save) + + outputs = outputs * std + mean + target = target * std + mean + inputs = pdfs_img.squeeze(1) * std + mean + + output_dic[fname[0]][slice_num[0]] = outputs[0] + target_dic[fname[0]][slice_num[0]] = target[0] + input_dic[fname[0]][slice_num[0]] = inputs[0] + our_nmse = nmse(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_psnr = psnr(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_ssim = ssim(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + + print('name:{}, slice:{}, psnr:{}, ssim:{}'.format(name, slice_num[0], our_psnr, our_ssim)) + + + for name in output_dic.keys(): + f_output = torch.stack([v for _, v in output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.update(our_nmse, 1) + psnr_meter.update(our_psnr, 1) + ssim_meter.update(our_ssim, 1) + + print("==> Evaluate Metric") + print("Results ----------") + print("NMSE: {:.4}".format(np.array(nmse_meter.score).mean())) + print("PSNR: {:.4}".format(np.array(psnr_meter.score).mean())) + print("SSIM: {:.4}".format(np.array(ssim_meter.score).mean())) + print("NMSE: {:.4}".format(np.array(nmse_meter.score).std())) + print("PSNR: {:.4}".format(np.array(psnr_meter.score).std())) + print("SSIM: {:.4}".format(np.array(ssim_meter.score).std())) + print("------------------") + model.train() + return {'NMSE': nmse_meter.avg, 'PSNR': psnr_meter.avg, 'SSIM':ssim_meter.avg} + +from dataloaders.fastmri import build_dataset +if __name__ == "__main__": + ## make logger file + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + network = TwoBranch(args).cuda() + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + + db_test = build_dataset(args, mode='val') + testloader = DataLoader(db_test, 
batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'test': + + save_mode_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + print('load weights from ' + save_mode_path) + checkpoint = torch.load(save_mode_path) + + + weights_dict = {} + for k, v in checkpoint['network'].items(): + new_k = k.replace('module.', '') if 'module' in k else k + weights_dict[new_k] = v + # breakpoint() + network.load_state_dict(weights_dict) + network.eval() + + eval_result = evaluate(network, testloader, device, save_path = snapshot_path + '/result_case/') + + + diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/train_brats.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/train_brats.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e3a6c1d7a3d73bb3e5a8347b51f6ea41084c61 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/train_brats.py @@ -0,0 +1,328 @@ +import os +import sys +from tqdm import tqdm +from tensorboardX import SummaryWriter +import shutil +import argparse +import logging +import time +import torch +import numpy as np +import torch.optim as optim +from torchvision import transforms +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision.utils import make_grid +from networks.compare_models import build_model_from_name +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import RandomPadCrop, ToTensor, AddNoise +from networks.mynet import TwoBranch +from option import args +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr + + + +def cc(img1, img2): + eps = torch.finfo(torch.float32).eps + """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" + N, C, _, _ = img1.shape + img1 = img1.reshape(N, C, -1) + img2 = img2.reshape(N, C, -1) + img1 = img1 - img1.mean(dim=-1, keepdim=True) + img2 = img2 - img2.mean(dim=-1, keepdim=True) + cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum( + img1 **2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1))) + cc = torch.clamp(cc, -1., 1.) 
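# mean Pearson correlation over all (N, C) pairs; the clamp above only guards against small numerical overshoot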
+ return cc.mean() + + + +def gradient_calllback(network): + for name, param in network.named_parameters(): + if param.grad is not None: + # print("Gradient of {}: {}".format(name, param.grad.abs().mean())) + + if param.grad.abs().mean() == 0: + print("Gradient of {} is 0".format(name)) + + else: + print("Gradient of {} is None".format(name)) + +class AMPLoss(nn.Module): + def __init__(self): + super(AMPLoss, self).__init__() + self.cri = nn.L1Loss() + + def forward(self, x, y): + x = torch.fft.rfft2(x, norm='backward') + x_mag = torch.abs(x) + y = torch.fft.rfft2(y, norm='backward') + y_mag = torch.abs(y) + + return self.cri(x_mag,y_mag) + + +class PhaLoss(nn.Module): + def __init__(self): + super(PhaLoss, self).__init__() + self.cri = nn.L1Loss() + + def forward(self, x, y): + x = torch.fft.rfft2(x, norm='backward') + x_mag = torch.angle(x) + y = torch.fft.rfft2(y, norm='backward') + y_mag = torch.angle(y) + + return self.cri(x_mag, y_mag) + +if __name__ == "__main__": + ## make logger file + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + network = TwoBranch(args).cuda() + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + # network = nn.SyncBatchNorm.convert_sync_batchnorm(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_train = MyDataset(split='train', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + # transform=transforms.Compose([RandomPadCrop(), ToTensor(), AddNoise()]), + transform=transforms.Compose([RandomPadCrop(), ToTensor()]), + base_dir=train_data_path, input_normalize = args.input_normalize) + + db_test = MyDataset(split='test', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([ToTensor()]), + base_dir=test_data_path, input_normalize = args.input_normalize) + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + fixtrainloader = DataLoader(db_train, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'train': + network.train() + + params = list(network.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + max_epoch = max_iterations // len(trainloader) + 1 + + + best_status = {'T1_NMSE': 10000000, 'T1_PSNR': 0, 'T1_SSIM': 0, + 'T2_NMSE': 10000000, 'T2_PSNR': 0, 'T2_SSIM': 0} + fft_weight=0.01 + criterion = nn.L1Loss().to(device, non_blocking=True) + amploss = AMPLoss().to(device, non_blocking=True) + phaloss = PhaLoss().to(device, non_blocking=True) + start_time = time.time() + + for epoch_num in tqdm(range(max_epoch), ncols=70): + time1 = time.time() + debug_time = False + + # Data Preparation Time: 0.01880049705505371 + # Network Forward Time: 0.08233189582824707 + # Loss Calculation Time: 0.08654212951660156 + # Optimizer Step Time: 0.4485752582550049 + + for i_batch, (sampled_batch, sample_stats) in enumerate(trainloader): 
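# Editor's note: sampled_batch supplies the network inputs (image_in, target_in),
# the reference images (image, target) and the k-space reconstructions
# (image_krecon, target_krecon); the krecon tensors are loaded here but only
# compared against in the evaluation block further down. The objective below is
# L1 on img_out plus L1, amplitude-spectrum (AMPLoss) and phase-spectrum
# (PhaLoss) terms on img_fre, with the two spectral terms scaled by fft_weight.
# For intuition, AMPLoss()(torch.rand(1, 1, 64, 64), torch.rand(1, 1, 64, 64))
# returns a scalar L1 distance between the two amplitude spectra.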
+ time2 = time.time() + + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + + + time3 = time.time() + + if debug_time: + print("Data Preparation Time: ", time3 - time2) + print("t1, t2=", t1.shape, t2.shape) + + outputs = network(t2_in, t1_in) + if debug_time: + print("Network Forward Time: ", time.time() - time2) + + loss = criterion(outputs['img_out'], t2) + \ + fft_weight * amploss(outputs['img_fre'], t2) + fft_weight * phaloss( + outputs['img_fre'], + t2) + \ + criterion(outputs['img_fre'], t2) + if debug_time: + print("Loss Calculation Time: ", time.time() - time2) + + time4 = time.time() + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + ### clip the gradients to a small range. + torch.nn.utils.clip_grad_norm_(network.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + if debug_time: + print("Optimizer Step Time: ", time.time() - time2) + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + # writer.add_scalar('lr', scheduler1.get_lr(), iter_num) + # writer.add_scalar('loss/loss', loss, iter_num) + + if iter_num % 100 == 0: + logging.info('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % (iter_num, time.time()-start_time, scheduler1.get_lr()[0], loss.item())) + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': network.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + + ## ================ Evaluate ================ + logging.info(f'Epoch {epoch_num} Evaluation:') + # print() + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + t1_MSE_krecon, t1_PSNR_krecon, t1_SSIM_krecon = [], [], [] + t2_MSE_krecon, t2_PSNR_krecon, t2_SSIM_krecon = [], [], [] + + for (sampled_batch, sample_stats) in testloader: + + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + t_merge = torch.cat([t1_in, t2_in], dim=1) + + t2_out = network(t2_in, t1_in)['img_out'] + t1_out = None + + if args.input_normalize == "mean_std": + t1_mean = sample_stats['t1_mean'].data.cpu().numpy()[0] + t1_std = sample_stats['t1_std'].data.cpu().numpy()[0] + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + + + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + + else: + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0], 0, 1) * 
255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + + + if t1_out is not None: + + MSE = mean_squared_error(t1_img, t1_out_img) + PSNR = peak_signal_noise_ratio(t1_img, t1_out_img) + SSIM = structural_similarity(t1_img, t1_out_img) + t1_MSE_all.append(MSE) + t1_PSNR_all.append(PSNR) + t1_SSIM_all.append(SSIM) + + MSE = mean_squared_error(t1_img, t1_krecon_img) + PSNR = peak_signal_noise_ratio(t1_img, t1_krecon_img) + SSIM = structural_similarity(t1_img, t1_krecon_img) + t1_MSE_krecon.append(MSE) + t1_PSNR_krecon.append(PSNR) + t1_SSIM_krecon.append(SSIM) + + + if t2_out is not None: + MSE = mean_squared_error(t2_img, t2_out_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_out_img) + SSIM = structural_similarity(t2_img, t2_out_img) + t2_MSE_all.append(MSE) + t2_PSNR_all.append(PSNR) + t2_SSIM_all.append(SSIM) + # print("[t2 MRI] MSE:", MSE, "PSNR:", PSNR, "SSIM:", SSIM) + + MSE = mean_squared_error(t2_img, t2_krecon_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_krecon_img) + SSIM = structural_similarity(t2_img, t2_krecon_img) + t2_MSE_krecon.append(MSE) + t2_PSNR_krecon.append(PSNR) + t2_SSIM_krecon.append(SSIM) + + if t1_out is not None: + t1_mse = np.array(t1_MSE_all).mean() + t1_psnr = np.array(t1_PSNR_all).mean() + t1_ssim = np.array(t1_SSIM_all).mean() + + t1_krecon_mse = np.array(t1_MSE_krecon).mean() + t1_krecon_psnr = np.array(t1_PSNR_krecon).mean() + t1_krecon_ssim = np.array(t1_SSIM_krecon).mean() + + t2_mse = np.array(t2_MSE_all).mean() + t2_psnr = np.array(t2_PSNR_all).mean() + t2_ssim = np.array(t2_SSIM_all).mean() + + t2_krecon_mse = np.array(t2_MSE_krecon).mean() + t2_krecon_psnr = np.array(t2_PSNR_krecon).mean() + t2_krecon_ssim = np.array(t2_SSIM_krecon).mean() + + + if t2_psnr > best_status['T2_PSNR']: + best_status = {'T2_NMSE': t2_mse, 'T2_PSNR': t2_psnr, 'T2_SSIM': t2_ssim} + + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': network.state_dict()}, best_checkpoint_path) + print('New Best Network:') + + logging.info(f"[T2 MRI:] average MSE: {t2_mse} average PSNR: {t2_psnr} average SSIM: {t2_ssim}") + + if iter_num > max_iterations: + break + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': network.state_dict()}, + save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + writer.close() diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/train_fastmri.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/train_fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..3e268812b01a003271ab21cfa0d7969f1e86ed4d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/train_fastmri.py @@ -0,0 +1,303 @@ +import os +import sys +from tqdm import tqdm +from tensorboardX import SummaryWriter +import shutil +import argparse +import logging +import time +import torch +import numpy as np +import torch.optim as optim +from torchvision import transforms +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision.utils 
import make_grid +from networks.compare_models import build_model_from_name +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import RandomPadCrop, ToTensor, AddNoise +from networks.mynet import TwoBranch +from option import args +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from dataloaders.fastmri import build_dataset + + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr + + + +def cc(img1, img2): + eps = torch.finfo(torch.float32).eps + """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" + N, C, _, _ = img1.shape + img1 = img1.reshape(N, C, -1) + img2 = img2.reshape(N, C, -1) + img1 = img1 - img1.mean(dim=-1, keepdim=True) + img2 = img2 - img2.mean(dim=-1, keepdim=True) + cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum( + img1 **2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1))) + cc = torch.clamp(cc, -1., 1.) + return cc.mean() + + + +def gradient_calllback(network): + for name, param in network.named_parameters(): + if param.grad is not None: + if param.grad.abs().mean() == 0: + print("Gradient of {} is 0".format(name)) + + else: + print("Gradient of {} is None".format(name)) + +class AMPLoss(nn.Module): + def __init__(self): + super(AMPLoss, self).__init__() + self.cri = nn.L1Loss() + + def forward(self, x, y): + x = torch.fft.rfft2(x, norm='backward') + x_mag = torch.abs(x) + y = torch.fft.rfft2(y, norm='backward') + y_mag = torch.abs(y) + + return self.cri(x_mag,y_mag) + + +class PhaLoss(nn.Module): + def __init__(self): + super(PhaLoss, self).__init__() + self.cri = nn.L1Loss() + + def forward(self, x, y): + x = torch.fft.rfft2(x, norm='backward') + x_mag = torch.angle(x) + y = torch.fft.rfft2(y, norm='backward') + y_mag = torch.angle(y) + + return self.cri(x_mag, y_mag) + +from metric import nmse, psnr, ssim, AverageMeter +from collections import defaultdict +@torch.no_grad() +def evaluate(model, data_loader, device): + model.eval() + nmse_meter = AverageMeter() + psnr_meter = AverageMeter() + ssim_meter = AverageMeter() + output_dic = defaultdict(dict) + target_dic = defaultdict(dict) + input_dic = defaultdict(dict) + + for id, data in enumerate(data_loader): + pd, pdfs, _ = data + target = pdfs[1] + + mean = pdfs[2] + std = pdfs[3] + + # print("get mean and std:", mean, std) + + fname = pdfs[4] + slice_num = pdfs[5] + + mean = mean.unsqueeze(1).unsqueeze(2) + std = std.unsqueeze(1).unsqueeze(2) + + mean = mean.to(device) + std = std.to(device) + + + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + + pd_img = pd_img.to(device) + pdfs_img = pdfs_img.to(device) + target = target.to(device) + + + outputs = model(pdfs_img, pd_img)['img_out'] + outputs = outputs.squeeze(1) + + # print("outputs shape:", outputs.shape, outputs.min(), outputs.max()) + + outputs = outputs * std + mean + target = target * std + mean + inputs = pdfs_img.squeeze(1) * std + mean + + # print("Ourputs after denormalization:", outputs.min(), outputs.max()) + + for i, f in enumerate(fname): + output_dic[f][slice_num[i]] = outputs[i] + target_dic[f][slice_num[i]] = target[i] + input_dic[f][slice_num[i]] = inputs[i] + + if id > 50: + break + + for name in output_dic.keys(): + f_output = torch.stack([v for _, v in 
output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + + nmse_meter.update(our_nmse, 1) + psnr_meter.update(our_psnr, 1) + ssim_meter.update(our_ssim, 1) + + print("==> Evaluate Metric") + print("Results ----------") + print("NMSE: {:.4}".format(nmse_meter.avg)) + print("PSNR: {:.4}".format(psnr_meter.avg)) + print("SSIM: {:.4}".format(ssim_meter.avg)) + print("------------------") + model.train() + + return {'NMSE': nmse_meter.avg, 'PSNR': psnr_meter.avg, 'SSIM':ssim_meter.avg} + + + +if __name__ == "__main__": + ## make logger file + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + network = TwoBranch(args).cuda() + # network = build_model_from_name(args).cuda() + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + # network = nn.SyncBatchNorm.convert_sync_batchnorm(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_train = build_dataset(args, mode='train') + db_test = build_dataset(args, mode='val') + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'train': + network.train() + + params = list(network.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + max_epoch = max_iterations // len(trainloader) + 1 + + + best_status = {'NMSE': 10000000, 'PSNR': 0, 'SSIM': 0} + fft_weight=0.01 + criterion = nn.L1Loss().to(device, non_blocking=True) + amploss = AMPLoss().to(device, non_blocking=True) + phaloss = PhaLoss().to(device, non_blocking=True) + for epoch_num in tqdm(range(max_epoch), ncols=70): + time1 = time.time() + for i_batch, sampled_batch in enumerate(trainloader): + time2 = time.time() + # print("time for data loading:", time2 - time1) + + pd, pdfs, _ = sampled_batch + target = pdfs[1] + + + mean = pdfs[2] + std = pdfs[3] + + + # print("mean:", mean, "std:", std) + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + target = target.unsqueeze(1) + + pd_img = pd_img.to(device) # [4, 1, 320, 320] + pdfs_img = pdfs_img.to(device) # [4, 1, 320, 320] + target = target.to(device) # [4, 1, 320, 320] + + time3 = time.time() + # breakpoint() + outputs = network(pdfs_img, pd_img) + + loss = criterion(outputs['img_out'], target) + \ + fft_weight * amploss(outputs['img_fre'], target) + fft_weight * phaloss( + outputs['img_fre'], + target) + \ + criterion(outputs['img_fre'], target) + + time4 = time.time() + + + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + ### clip the gradients to a small range. 
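# clip_grad_norm_ rescales the global gradient norm of all parameters to at most max_norm; 0.01 is a very tight cap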
+ torch.nn.utils.clip_grad_norm_(network.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + + if iter_num % 100 == 0: + logging.info('iteration %d : learning rate : %f loss : %f ' % (iter_num, scheduler1.get_lr()[0], loss.item())) + break + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': network.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + + ## ================ Evaluate ================ + logging.info(f'Epoch {epoch_num} Evaluation:') + # print() + eval_result = evaluate(network, testloader, device) + + if eval_result['PSNR'] > best_status['PSNR']: + best_status = {'NMSE': eval_result['NMSE'], 'PSNR': eval_result['PSNR'], 'SSIM': eval_result['SSIM']} + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': network.state_dict()}, best_checkpoint_path) + print('New Best Network:') + logging.info(f"average MSE: {eval_result['NMSE']} average PSNR: {eval_result['PSNR']} average SSIM: {eval_result['SSIM']}") + + if iter_num > max_iterations: + break + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': network.state_dict()}, + save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + writer.close() diff --git a/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/utils.py b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f733ac3d3d6527cae765d48cf58b0c02167532 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/experiments/FSMNet/utils.py @@ -0,0 +1,33 @@ +import numpy as np +import torch + + +def bright(x, a,b): + # input datatype np.uint8 + x = np.array(x, dtype='float') + x = x/(b-a) - 255*a/(b-a) + x[x>255.0] = 255.0 + x[x<0.0] = 0.0 + x = x.astype(np.uint8) + return x + +def trunc(x): + # input datatype float + x[x>255.0] = 255.0 + x[x<0.0] = 0.0 + return x + + + +def cc(img1, img2): + eps = torch.finfo(torch.float32).eps + """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" + N, C, _, _ = img1.shape + img1 = img1.reshape(N, C, -1) + img2 = img2.reshape(N, C, -1) + img1 = img1 - img1.mean(dim=-1, keepdim=True) + img2 = img2 - img2.mean(dim=-1, keepdim=True) + cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum( + img1 **2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1))) + cc = torch.clamp(cc, -1., 1.) 
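# identical to cc() in train_brats.py / train_fastmri.py; kept in utils so other modules can import it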
+ return cc.mean() \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/main.py b/MRI_recon/code/Frequency-Diffusion/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e98f1a3b2df63dd5a49387208aaca57dc493a0f0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/main.py @@ -0,0 +1,249 @@ +import torchvision +import os +import errno +import shutil +import argparse +from networks import TwoBranchModel,Unet +from diffusion_pytorch import GaussianDiffusion, Trainer +import torch, warnings + +from pytorch_lightning.callbacks import Callback +warnings.filterwarnings("ignore") + + +class DebugDataloaderCallback(Callback): + # + def __init__(self): + super().__init__() + self.counter = 0 + + def on_train_start(self, trainer, pl_module): + self.counter += 1 + if (self.counter + 1 ) % 10 == 0: + trainer.train_dataloader.dataset.update_chunk() + + + +def create_folder(path): + try: + os.mkdir(path) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + pass + + +def del_folder(path): + try: + shutil.rmtree(path) + except OSError as exc: + pass + + +create = 0 + +if create: + trainset = torchvision.datasets.CIFAR10( + root='./data', train=True, download=True) + root = './root_cifar10/' + del_folder(root) + create_folder(root) + + for i in range(10): + lable_root = root + str(i) + '/' + create_folder(lable_root) + + for idx in range(len(trainset)): + img, label = trainset[idx] + print(idx) + img.save(root + str(label) + '/' + str(idx) + '.png') + + +parser = argparse.ArgumentParser() +parser.add_argument('--time_steps', default=50, type=int) +parser.add_argument('--train_steps', default=700000, type=int) +parser.add_argument('--save_folder', default=None, type=str) + +parser.add_argument('--load_path', default=None, type=str) +parser.add_argument('--data_path', default='./root_cifar10/', type=str) +parser.add_argument('--fade_routine', default='Random_Incremental', type=str) +parser.add_argument('--sampling_routine', default='x0_step_down', type=str) +parser.add_argument('--discrete', action="store_true") +parser.add_argument('--remove_time_embed', action="store_true") +parser.add_argument('--residual', action="store_true") +parser.add_argument('--tag', default='', type=str) +parser.add_argument('--accelerate_factor', default=4, help="4 | 8", type=int) + + +parser.add_argument('--normalizer', default='mean_std', type=str) + +parser.add_argument('--mode', default='train', type=str) +parser.add_argument('--example_frequency_img', default=None, type=str) +# specific arguments +# parser.add_argument('--initial_mask', default=11, type=int) +parser.add_argument('--kernel_std', default=0.1, type=float) + +parser.add_argument('--dataset', default='brain', type=str) +parser.add_argument('--domain', default=None, type=str) +parser.add_argument('--aux_modality', default=None, type=str) +parser.add_argument('--deviceid', default=0, type=int) +parser.add_argument('--num_channels', default=1, type=int) +parser.add_argument('--train_bs', default=24, type=int) +parser.add_argument('--diffusion_type', default='twobranch_fade', type=str) +parser.add_argument('--debug', action="store_true") +parser.add_argument('--image_size', default=128) +parser.add_argument('--loss_type', default='l1', type=str) + +args = parser.parse_args() +print(args) +os.environ["CUDA_VISIBLE_DEVICES"] = str(args.deviceid) + +image_channels = 1 + +diffusion_type = args.diffusion_type +# diffusion_type = "twobranch_fade" # model_degradation # fade | kspace +model_name = 
diffusion_type.split("_")[0] # unet | twobranch + +save_and_sample_every = 1000 + +if args.debug: + args.train_steps = 100 + args.time_steps = 5 + +model = None + + +if isinstance(args.image_size, str): + length = len(args.image_size.split(",")) + if length == 1: + args.image_size = (int(args.image_size), int(args.image_size)) + elif length == 2: + args.image_size = (int(args.image_size.split(",")[0]), int(args.image_size.split(",")[1])) +else: + args.image_size = (args.image_size, args.image_size) + + + +if model_name == "unet": + model = Unet(resolution=args.image_size[0], + in_channels=1, + out_ch=1, + ch=128, + ch_mult=(1, 2, 2, 2), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.1).cuda() + +elif model_name == "twounet": + model = TwoBranchNewModel(resolution=args.image_size[0], + in_channels=1, + out_ch=1, + ch=128, + ch_mult=(1, 2, 2, 2), + num_res_blocks=3, + attn_resolutions=(16,), + dropout=0.1).cuda() # Drop out used to be 0.1 + + +elif model_name == "twobranch": + + base_num_every_group = 2 + num_features = 64 + act = "PReLU" + num_channels = 1 + + from networks.networks_fsm.mynet import TwoBranch as TwoBranchModel + + + model = TwoBranchModel( + num_features, act, base_num_every_group, num_channels + ).cuda() + +fp16 = False + + +n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) +print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + +diffusion = GaussianDiffusion( + diffusion_type, + model, + image_size=args.image_size[0], # Used to be 32 + channels=image_channels, + device_of_kernel='cuda', + timesteps=args.time_steps, + loss_type=args.loss_type, #$'l1', + kernel_std=args.kernel_std, + fade_routine=args.fade_routine, + sampling_routine=args.sampling_routine, + discrete=args.discrete, + accelerate_factor=args.accelerate_factor, + fp16=fp16, + normalizer=args.normalizer, + example_frequency_img=args.example_frequency_img, +).cuda() + + +diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) + +print("=== train_steps:", args.train_steps) +os.makedirs(args.save_folder, exist_ok=True) + +if args.debug: + args.save_folder = args.save_folder + "_debug" +else: + args.save_folder = args.save_folder + f"_{args.tag}" + save_and_sample_every = 500 + + +# if os.path.exists(args.save_folder): +name = args.save_folder.split("/")[-1] +number = os.listdir(args.save_folder.rstrip(name)).__len__() +if args.mode == "test": + number = "test_" + str(number) + +args.save_folder = os.path.join(args.save_folder.rstrip(name), f"{number}_" + name) + +# create the folder and parent folders +os.makedirs(args.save_folder, exist_ok=True) + + + +print("SAVE FOLDER: ", args.save_folder) + +trainer = Trainer( + diffusion, + args.data_path, + mode = args.mode, + norm = args.normalizer, + image_size=args.image_size, # Used to be 32 + train_batch_size=args.train_bs, + train_lr= 1e-4, # 2e-5 + train_num_steps=args.train_steps, + gradient_accumulate_every=1, + ema_decay=0.995, + save_and_sample_every=save_and_sample_every, + fp16=fp16, + results_folder=args.save_folder, + load_path=args.load_path, + dataset=args.dataset, + domain=args.domain, + aux_modality=args.aux_modality, + debug=args.debug, + num_channels=args.num_channels + # accelerator="gpu", + # callbacks=[DebugDataloaderCallback()], +) + + +if args.mode == "train": + trainer.train() + +elif args.mode == "test": + # ['default', 'x0_step_down', 'x0_step_down_fre', "fre_progressive"]: + trainer.test_loader('x0_step_down_fre') + + + + diff --git 
a/MRI_recon/code/Frequency-Diffusion/metrics/fid.py b/MRI_recon/code/Frequency-Diffusion/metrics/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa3918b97302d209d4fb2f88dc50ca7ef1476b5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/metrics/fid.py @@ -0,0 +1,334 @@ +import torch +from torch import nn +from torchvision.models import inception_v3 +import cv2 +import multiprocessing +import numpy as np +import glob +import os +from scipy import linalg + + +def to_cuda(elements): + """ + Transfers elements to cuda if GPU is available + Args: + elements: torch.tensor or torch.nn.module + -- + Returns: + elements: same as input on GPU memory, if available + """ + if torch.cuda.is_available(): + return elements.cuda() + return elements + + +class PartialInceptionNetwork(nn.Module): + + def __init__(self, transform_input=True): + super().__init__() + self.inception_network = inception_v3(pretrained=True) + self.inception_network.Mixed_7c.register_forward_hook(self.output_hook) + self.transform_input = transform_input + + def output_hook(self, module, input, output): + # N x 2048 x 8 x 8 + self.mixed_7c_output = output + + def forward(self, x): + """ + Args: + x: shape (N, 3, 299, 299) dtype: torch.float32 in range 0-1 + Returns: + inception activations: torch.tensor, shape: (N, 2048), dtype: torch.float32 + """ + assert x.shape[1:] == (3, 299, 299), "Expected input shape to be: (N,3,299,299)" + \ + ", but got {}".format(x.shape) + x = x * 2 - 1 # Normalize to [-1, 1] + + # Trigger output hook + self.inception_network(x) + + # Output: N x 2048 x 1 x 1 + activations = self.mixed_7c_output + activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1)) + activations = activations.view(x.shape[0], 2048) + return activations + +class PartialResnet3D(nn.Module): + + def __init__(self, transform_input=True): + super().__init__() + model = torch.hub.load('facebookresearch/pytorchvideo', 'slow_r50', + pretrained=True) + + model.blocks[5].proj = nn.Identity() + model.blocks[5].output_pool = nn.Identity() + + self.network = model + + # input = torch.ones(1, 3, 8, 256, 256) + + + # self.inception_network.Mixed_7c.register_forward_hook(self.output_hook) + self.transform_input = transform_input + + # def output_hook(self, module, input, output): + # N x 2048 x 8 x 8 + # self.mixed_7c_output = output + + def forward(self, x): + """ + Args: + x: shape (N, 3, 299, 299) dtype: torch.float32 in range 0-1 + Returns: + inception activations: torch.tensor, shape: (N, 2048), dtype: torch.float32 + """ + # assert x.shape[1:] == (3, 256, 256), "Expected input shape to be: (N,3,299,299)" + \ + # ", but got {}".format(x.shape) + + x = x * 2 - 1 # Normalize to [-1, 1] + + # Trigger output hook + activations = self.inception_network(x) + + + # Output: N x 2048 x 1 x 1 + # activations = self.mixed_7c_output + activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1)) + activations = activations.view(x.shape[0], 2048) + return activations + + + +def get_activations(images, batch_size): + """ + Calculates activations for last pool layer for all iamges + -- + Images: torch.array shape: (N, 3, 299, 299), dtype: torch.float32 + batch size: batch size used for inception network + -- + Returns: np array shape: (N, 2048), dtype: np.float32 + """ + assert images.shape[1:] == (3, 299, 299), "Expected input shape to be: (N,3,299,299)" + \ + ", but got {}".format(images.shape) + + num_images = images.shape[0] + inception_network = PartialInceptionNetwork() + 
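# note: the Inception-v3 feature extractor is re-instantiated on every call to get_activations;
# for repeated FID evaluations it could be built once and reused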
inception_network = to_cuda(inception_network) + inception_network.eval() + n_batches = int(np.ceil(num_images / batch_size)) + inception_activations = np.zeros((num_images, 2048), dtype=np.float32) + for batch_idx in range(n_batches): + start_idx = batch_size * batch_idx + end_idx = batch_size * (batch_idx + 1) + + ims = images[start_idx:end_idx] + ims = to_cuda(ims) + activations = inception_network(ims) + activations = activations.detach().cpu().numpy() + assert activations.shape == (ims.shape[0], 2048), "Expexted output shape to be: {}, but was: {}".format( + (ims.shape[0], 2048), activations.shape) + inception_activations[start_idx:end_idx, :] = activations + return inception_activations + + +def calculate_activation_statistics(images, batch_size): + """Calculates the statistics used by FID + Args: + images: torch.tensor, shape: (N, 3, H, W), dtype: torch.float32 in range 0 - 1 + batch_size: batch size to use to calculate inception scores + Returns: + mu: mean over all activations from the last pool layer of the inception model + sigma: covariance matrix over all activations from the last pool layer + of the inception model. + + """ + act = get_activations(images, batch_size) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +# Modified from: https://github.com/bioinf-jku/TTUR/blob/master/fid.py +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of the pool_3 layer of the + inception net ( like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted + on an representive data set. + -- sigma1: The covariance matrix over activations of the pool_3 layer for + generated samples. + -- sigma2: The covariance matrix over activations of the pool_3 layer, + precalcualted on an representive data set. + + Returns: + -- : The Frechet Distance. 
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +def preprocess_image(im): + """Resizes and shifts the dynamic range of image to 0-1 + Args: + im: np.array, shape: (H, W, 3), dtype: float32 between 0-1 or np.uint8 + Return: + im: torch.tensor, shape: (3, 299, 299), dtype: torch.float32 between 0-1 + """ + # print("im shape:", im.shape) + if im.shape[0] == 3: + im = im.transpose(1, 2, 0) + # CHW->HWC + + # print("new im shape:", im.shape) + + + assert im.shape[2] == 3 + assert len(im.shape) == 3 + if im.dtype == np.uint8: + im = im.astype(np.float32) / 255 + + im = cv2.resize(im, (299, 299)) + im = np.rollaxis(im, axis=2) + im = torch.from_numpy(im) + assert im.max() <= 1.0 + assert im.min() >= 0.0 + assert im.dtype == torch.float32 + assert im.shape == (3, 299, 299) + + return im + + +def preprocess_images(images, use_multiprocessing): + """Resizes and shifts the dynamic range of image to 0-1 + Args: + images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + Return: + final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1 + """ + if use_multiprocessing: + with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: + jobs = [] + for im in images: + job = pool.apply_async(preprocess_image, (im,)) + jobs.append(job) + final_images = torch.zeros(images.shape[0], 3, 299, 299) + for idx, job in enumerate(jobs): + im = job.get() + final_images[idx] = im # job.get() + else: + final_images = torch.stack([preprocess_image(im) for im in images], dim=0) + + assert final_images.shape == (images.shape[0], 3, 299, 299) + assert final_images.max() <= 1.0 + assert final_images.min() >= 0.0 + assert final_images.dtype == torch.float32 + return final_images + + +def calculate_fid(images1, images2, use_multiprocessing, batch_size): + """ Calculate FID between images1 and images2 + Args: + images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + batch size: batch size used for inception network + Returns: + FID (scalar) + """ + images1 = preprocess_images(images1, use_multiprocessing) + images2 = preprocess_images(images2, use_multiprocessing) + mu1, sigma1 = calculate_activation_statistics(images1, batch_size) + mu2, sigma2 = 
calculate_activation_statistics(images2, batch_size) + fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) + return fid + + +def load_images(path): + """ Loads all .png or .jpg images from a given path + Warnings: Expects all images to be of same dtype and shape. + Args: + path: relative path to directory + Returns: + final_images: np.array of image dtype and shape. + """ + image_paths = [] + image_extensions = ["png", "jpg"] + for ext in image_extensions: + print("Looking for images in", os.path.join(path, "*.{}".format(ext))) + for impath in glob.glob(os.path.join(path, "*.{}".format(ext))): + image_paths.append(impath) + first_image = cv2.imread(image_paths[0]) + W, H = first_image.shape[:2] + image_paths.sort() + image_paths = image_paths + final_images = np.zeros((len(image_paths), H, W, 3), dtype=first_image.dtype) + for idx, impath in enumerate(image_paths): + im = cv2.imread(impath) + im = im[:, :, ::-1] # Convert from BGR to RGB + assert im.dtype == final_images.dtype + final_images[idx] = im + return final_images + + +if __name__ == "__main__": + from optparse import OptionParser + + parser = OptionParser() + parser.add_option("--p1", "--path1", dest="path1", + help="Path to directory containing the real images") + parser.add_option("--p2", "--path2", dest="path2", + help="Path to directory containing the generated images") + parser.add_option("--multiprocessing", dest="use_multiprocessing", + help="Toggle use of multiprocessing for image pre-processing. Defaults to use all cores", + default=False, + action="store_true") + parser.add_option("-b", "--batch-size", dest="batch_size", + help="Set batch size to use for InceptionV3 network", + type=int) + + options, _ = parser.parse_args() + assert options.path1 is not None, "--path1 is an required option" + assert options.path2 is not None, "--path2 is an required option" + assert options.batch_size is not None, "--batch_size is an required option" + images1 = load_images(options.path1) + images2 = load_images(options.path2) + fid_value = calculate_fid(images1, images2, options.use_multiprocessing, options.batch_size) + print(fid_value) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/metrics/fid_3d.py b/MRI_recon/code/Frequency-Diffusion/metrics/fid_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..054c7ed99d5aa823814bed42c4de921fb8cdf56b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/metrics/fid_3d.py @@ -0,0 +1,307 @@ +import torch +from torch import nn +from torchvision.models import inception_v3 +import cv2 +import multiprocessing +import numpy as np +import glob +import os +from scipy import linalg + + +def to_cuda(elements): + """ + Transfers elements to cuda if GPU is available + Args: + elements: torch.tensor or torch.nn.module + -- + Returns: + elements: same as input on GPU memory, if available + """ + if torch.cuda.is_available(): + return elements.cuda() + return elements + + + + +class PartialResnet3D(nn.Module): + + def __init__(self, transform_input=True): + super().__init__() + model = torch.hub.load('facebookresearch/pytorchvideo', 'slow_r50', + pretrained=True) + + model.blocks[5].proj = nn.Identity() + model.blocks[5].output_pool = nn.Identity() + + self.network = model + + # input = torch.ones(1, 3, 8, 256, 256) + + # self.inception_network.Mixed_7c.register_forward_hook(self.output_hook) + self.transform_input = transform_input + + # def output_hook(self, module, input, output): + # N x 98304 x 8 x 8 + # self.mixed_7c_output = output + + def 
forward(self, x): + """ + Args: + x: shape (N, 3, 299, 299) dtype: torch.float32 in range 0-1 + Returns: + inception activations: torch.tensor, shape: (N, 98304), dtype: torch.float32 + """ + # assert x.shape[1:] == (3, 256, 256), "Expected input shape to be: (N,3,299,299)" + \ + # ", but got {}".format(x.shape) + + x = x * 2 - 1 # Normalize to [-1, 1] + + # Trigger output hook + activations = self.network(x) + # print("activations shape:", activations.shape) # activations shape: torch.Size([1, 98304]) + + # Output: N x 98304 x 1 x 1 + # activations = self.mixed_7c_output + # activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1)) + activations = activations.view(x.shape[0], 98304) + return activations + + +def get_activations(images, batch_size): + """ + Calculates activations for last pool layer for all iamges + -- + Images: torch.array shape: (N, 3, 299, 299), dtype: torch.float32 + batch size: batch size used for inception network + -- + Returns: np array shape: (N, 98304), dtype: np.float32 + """ + # assert images.shape[1:] == (3, 299, 299), "Expected input shape to be: (N,3,299,299)" + \ + # ", but got {}".format(images.shape) + + num_images = images.shape[0] + inception_network = PartialResnet3D() + inception_network = to_cuda(inception_network) + inception_network.eval() + n_batches = int(np.ceil(num_images / batch_size)) + inception_activations = np.zeros((num_images, 98304), dtype=np.float32) + for batch_idx in range(n_batches): + start_idx = batch_size * batch_idx + end_idx = batch_size * (batch_idx + 1) + + ims = images[start_idx:end_idx] + ims = to_cuda(ims) + activations = inception_network(ims) + activations = activations.detach().cpu().numpy() + assert activations.shape == (ims.shape[0], 98304), "Expexted output shape to be: {}, but was: {}".format( + (ims.shape[0], 98304), activations.shape) + inception_activations[start_idx:end_idx, :] = activations + return inception_activations + + +def calculate_activation_statistics(images, batch_size): + """Calculates the statistics used by FID + Args: + images: torch.tensor, shape: (N, 3, H, W), dtype: torch.float32 in range 0 - 1 + batch_size: batch size to use to calculate inception scores + Returns: + mu: mean over all activations from the last pool layer of the inception model + sigma: covariance matrix over all activations from the last pool layer + of the inception model. + + """ + act = get_activations(images, batch_size) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +# Modified from: https://github.com/bioinf-jku/TTUR/blob/master/fid.py +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of the pool_3 layer of the + inception net ( like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted + on an representive data set. + -- sigma1: The covariance matrix over activations of the pool_3 layer for + generated samples. + -- sigma2: The covariance matrix over activations of the pool_3 layer, + precalcualted on an representive data set. + + Returns: + -- : The Frechet Distance. 
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +def preprocess_image(im): + """Resizes and shifts the dynamic range of image to 0-1 + Args: + im: np.array, shape: (H, W, 3), dtype: float32 between 0-1 or np.uint8 + Return: + im: torch.tensor, shape: (3, 299, 299), dtype: torch.float32 between 0-1 + """ + # print("im shape:", im.shape) + if im.shape[0] == 3: + im = im.transpose(1, 2, 0) + # CHW->HWC + + # print("new im shape:", im.shape) + + assert im.shape[2] == 3 + assert len(im.shape) == 3 + if im.dtype == np.uint8: + im = im.astype(np.float32) / 255 + + im = cv2.resize(im, (299, 299)) + im = np.rollaxis(im, axis=2) + im = torch.from_numpy(im) + assert im.max() <= 1.0 + assert im.min() >= 0.0 + assert im.dtype == torch.float32 + assert im.shape == (3, 299, 299) + + return im + + +def preprocess_images(images, use_multiprocessing): + """Resizes and shifts the dynamic range of image to 0-1 + Args: + images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + Return: + final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1 + """ + if use_multiprocessing: + with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: + jobs = [] + for im in images: + job = pool.apply_async(preprocess_image, (im,)) + jobs.append(job) + final_images = torch.zeros(images.shape[0], 3, 256, 256) + for idx, job in enumerate(jobs): + im = job.get() + final_images[idx] = im # job.get() + else: + final_images = torch.stack([preprocess_image(im) for im in images], dim=0) + + # print("final_images shape:", final_images.shape) + # assert final_images.shape == (1, 3, images.shape[0], 256, 256) + assert final_images.max() <= 1.0 + assert final_images.min() >= 0.0 + assert final_images.dtype == torch.float32 + return final_images + + +def calculate_fid_3d(images1, images2, use_multiprocessing, batch_size): + """ Calculate FID between images1 and images2 + Args: + images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + batch size: batch size used for inception network + Returns: + FID (scalar) + """ + images1 = preprocess_images(images1, use_multiprocessing) + images2 = preprocess_images(images2, use_multiprocessing) + + # C, 3, H, W -> 1, 3, C, H, W + images1 = 
images1.unsqueeze(0).permute(0, 2, 1, 3, 4) + images2 = images2.unsqueeze(0).permute(0, 2, 1, 3, 4) + + mu1, sigma1 = calculate_activation_statistics(images1, batch_size) + mu2, sigma2 = calculate_activation_statistics(images2, batch_size) + fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) + return fid + + +def load_images(path): + """ Loads all .png or .jpg images from a given path + Warnings: Expects all images to be of same dtype and shape. + Args: + path: relative path to directory + Returns: + final_images: np.array of image dtype and shape. + """ + image_paths = [] + image_extensions = ["png", "jpg"] + for ext in image_extensions: + print("Looking for images in", os.path.join(path, "*.{}".format(ext))) + for impath in glob.glob(os.path.join(path, "*.{}".format(ext))): + image_paths.append(impath) + first_image = cv2.imread(image_paths[0]) + W, H = first_image.shape[:2] + image_paths.sort() + image_paths = image_paths + final_images = np.zeros((len(image_paths), H, W, 3), dtype=first_image.dtype) + for idx, impath in enumerate(image_paths): + im = cv2.imread(impath) + im = im[:, :, ::-1] # Convert from BGR to RGB + assert im.dtype == final_images.dtype + final_images[idx] = im + return final_images + + +if __name__ == "__main__": + from optparse import OptionParser + + parser = OptionParser() + parser.add_option("--p1", "--path1", dest="path1", + help="Path to directory containing the real images") + parser.add_option("--p2", "--path2", dest="path2", + help="Path to directory containing the generated images") + parser.add_option("--multiprocessing", dest="use_multiprocessing", + help="Toggle use of multiprocessing for image pre-processing. Defaults to use all cores", + default=False, + action="store_true") + parser.add_option("-b", "--batch-size", dest="batch_size", + help="Set batch size to use for InceptionV3 network", + type=int) + + options, _ = parser.parse_args() + assert options.path1 is not None, "--path1 is an required option" + assert options.path2 is not None, "--path2 is an required option" + assert options.batch_size is not None, "--batch_size is an required option" + images1 = load_images(options.path1) + images2 = load_images(options.path2) + fid_value = calculate_fid(images1, images2, options.use_multiprocessing, options.batch_size) + print(fid_value) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/metrics/frequency_loss.py b/MRI_recon/code/Frequency-Diffusion/metrics/frequency_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..642abd11d2dc3af909744b6fff3ac8463bf0f5ad --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/metrics/frequency_loss.py @@ -0,0 +1,43 @@ +import torch + + + +class AMPLoss(torch.nn.Module): + def __init__(self, epsilon=1e-8, loss="l1", norm='backward'): + super(AMPLoss, self).__init__() + self.mask_region = False # TODO + + if loss == "l1": + self.cri = torch.nn.L1Loss(reduction="sum" if self.mask_region else "mean") + else: + self.cri = torch.nn.MSELoss(reduction="sum" if self.mask_region else "mean") + + self.epsilon = epsilon # To prevent division by zero + self.norm = norm # Normalization for FFT + + def forward(self, x, y, k): + # Perform FFT and compute magnitudes + x_fft = torch.fft.rfft2(x, norm=self.norm) + y_fft = torch.fft.rfft2(y, norm=self.norm) + + x_mag = torch.clamp(torch.abs(x_fft), min=self.epsilon) # Clamp to avoid zeros + y_mag = torch.clamp(torch.abs(y_fft), min=self.epsilon) # Clamp to avoid zeros + + x_phase = torch.angle(x_fft) + y_phase = 
torch.angle(y_fft) + + if self.mask_region: + W = x.shape[-1] + k = (1 - k.to(x.device)) + k = k[..., :W // 2 + 1] + k_total = torch.sum(k) + + x_mag = x_mag * k + y_mag = y_mag * k + x_phase = x_phase * k + y_phase = y_phase * k + # Compute L1 loss between magnitudes + return self.cri(x_mag, y_mag)/k_total + self.cri(x_phase, y_phase)/k_total + else: + # Compute L1 loss between magnitudes + return self.cri(x_mag, y_mag) + self.cri(x_phase, y_phase) diff --git a/MRI_recon/code/Frequency-Diffusion/metrics/lpips.py b/MRI_recon/code/Frequency-Diffusion/metrics/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a30b875fb4aa39ccd8419759d2f841d62bbad6 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/metrics/lpips.py @@ -0,0 +1,184 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), 
strict=False) + return model + + def forward(self, input, target): + input = input.float() + target = target.float() + + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/MRI_recon/code/Frequency-Diffusion/metrics/nmse.py b/MRI_recon/code/Frequency-Diffusion/metrics/nmse.py new file mode 100644 index 0000000000000000000000000000000000000000..790122086edaf81b7c4e268adacb1fff6b0ce3a9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/metrics/nmse.py @@ -0,0 +1,5 @@ +import numpy as np + +def nmse(gt, pred): + """Compute Normalized Mean Squared Error (NMSE)""" + return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/Unet.py 
b/MRI_recon/code/Frequency-Diffusion/networks/Unet.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c154dd8c2912e0665589496ce1154080ecc1a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/Unet.py @@ -0,0 +1,332 @@ +import math +import torch +import torch.nn as nn + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if 
self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + + + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + 
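+ # Each decoder ResnetBlock takes the current features (block_in) concatenated
+ # with one skip connection popped from the encoder stack, hence
+ # in_channels=block_in+skip_in; the extra (num_res_blocks+1)-th block pairs with
+ # the skip left over from conv_in / the downsample path, which is why skip_in
+ # switches to ch*in_ch_mult[i_level] above.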
block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t): + assert x.shape[2] == x.shape[3] == self.resolution + + # timestep embedding + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + + # print(t) + # print(temb) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h diff --git a/MRI_recon/code/Frequency-Diffusion/networks/__init__.py b/MRI_recon/code/Frequency-Diffusion/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f42d964b4b0e18ce8995b50019d171bd2340ad --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/__init__.py @@ -0,0 +1,2 @@ +from .Unet import Model as Unet +from .st_branch_model.model import TwoBranchModel \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ARTUnet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae23fed16b0d9454a5ea09f17b4322f1e9d7e928 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ARTUnet.py @@ -0,0 +1,751 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
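+ # Padded rows/columns were flagged with -1 in `mask`; the masked_fill below turns
+ # them into NEG_INF logits so that, after softmax, the sparse (interval-strided)
+ # attention gives them effectively zero weight.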
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + # print("x shape:", x.shape) + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +########################################################################## +class CS_TransformerBlock(nn.Module): + r""" 在ART Transformer Block的基础上添加了channel attention的分支。 + 参考论文: Dual Aggregation Transformer for Image Super-Resolution + 及其代码: https://github.com/zhengchen1999/DAT/blob/main/basicsr/archs/dat_arch.py + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.InstanceNorm2d(dim), + nn.GELU()) + + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + # nn.InstanceNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1)) + + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.InstanceNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # print("x before dwconv:", x.shape) + # convolution output + conv_x = self.dwconv(x.permute(0, 3, 1, 2)) + # conv_x = x.permute(0, 3, 1, 2) + # print("x after dwconv:", x.shape) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + attened_x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + attened_x = attened_x[:, :H, :W, :].contiguous() + + attened_x = attened_x.view(B, H * W, C) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C) + # S-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W) + spatial_map = self.spatial_interaction(attention_reshape) + + # C-I + attened_x = attened_x * torch.sigmoid(channel_map) + # S-I + conv_x = torch.sigmoid(spatial_map) * conv_x + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + + x = attened_x + conv_x + + # 
FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ARTUnet_new.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ARTUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d210d46e2a4d73781b3b16b63a53fdc0f25b9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ARTUnet_new.py @@ -0,0 +1,823 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +Unet的格式构建image restoration 网络。每个transformer block中都是dense self-attention和sparse self-attention交替连接。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True, visualize_attention_maps=False): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + self.visualize_attention_maps = visualize_attention_maps + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + # assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + # qkv: 3, B_, num_heads, N, C // num_heads + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + # B_, num_heads, N, C // num_heads * + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + attn_map = attn.detach().cpu().numpy().mean(axis=1) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) # (B * nP, nHead, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + # attn: B * nP, num_heads, N, N + # v: B * nP, num_heads, N, C // num_heads + # -attn-@-v-> B * nP, num_heads, N, C // num_heads + # -transpose-> B * nP, N, num_heads, C 
// num_heads + # -reshape-> B * nP, N, C + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, G ** 2, G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, Gh * Gw, Gh * Gw) + else: + x = self.attn(x, Gh, Gw, mask=attn_mask) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +class ConcatTransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. 
+ num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + ################## + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // 
self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + ################## + + x = torch.cat((x, y), dim=1) + + shortcut = x + x = self.norm1(x) + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, 2 * Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, 2 * G ** 2, 2 * G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, 2 * Gh * Gw, 2 * Gh * Gw) + else: + x = self.attn(x, 2 * Gh, Gw, mask=attn_mask) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + output = x + + # print(x.shape, Gh, Gw) + x = output[:, :Gh * Gw, :] + y = output[:, Gh * Gw:, :] + assert x.shape == y.shape + + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = y.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = y.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, y, attn_map + else: + return x, y + + def get_patches(self, x, x_size): + H, W = x_size + B, H, W, C = x.shape + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, 
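+ # sparse grouping: after the permute, tokens that are `interval` pixels apart along
+ # each axis land in the same group, giving I*I groups of Gh*Gw = (Hd//I)*(Wd//I)
+ # tokens each, so every sparse-attention group spans the whole (padded) feature map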
Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + return x, attn_mask, (Hd, Wd), (Gh, Gw), (pad_r, pad_b), G, I + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
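+ # Output path (descriptive note, see forward below): when dual_pixel_task is True the
+ # patch-embedded input (dim channels) is lifted to 2*dim by the 1x1 skip_conv and added
+ # back before the final 3x3 output conv; note that inp_enc_level1 is still in
+ # (b, h*w, c) layout at that point, so this branch would need a rearrange back to
+ # (b, c, h, w) before skip_conv if it were enabled. Otherwise the output conv
+ # prediction is added to inp_img, i.e. the network learns a residual correction
+ # of its input.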
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ART_Restormer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ART_Restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c23123f0ac1a8e82e2bf907658ce8d91951c3e4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ART_Restormer.py @@ -0,0 +1,592 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure, 每个block都是由ART和Restormer串联而成。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
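+ # Padding mask: positions introduced by F.pad were marked -1 above, so the masked_fill
+ # below writes NEG_INF (-1e6) into the attention logits of every padded key; after the
+ # softmax inside Attention those columns receive ~zero weight.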
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[4, 4, 4, 4], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + + self.down1_2 = Downsample(dim) ## 
From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + 
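+ # Each resolution level is a serial ART -> Restormer stage: the ARTBlocks run on
+ # (b, h*w, c) token sequences, then the features are rearranged to (b, c, h, w) for the
+ # RestormerBlocks and flattened back (see forward below). Since there is no 1x1
+ # channel-reduction conv at the level 2 -> 1 transition, decoder level 1, the
+ # refinement blocks and the final 3x3 output conv all operate on int(dim * 2 ** 1)
+ # channels after the skip concatenation.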
self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + + ######################################### Encoder level1 ############################################ + ### ART encoder block + for layer in self.ART_encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + + ### Restormer encoder block + out_enc_level1 = rearrange(out_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = self.Restormer_encoder_level1(out_enc_level1) + # stack.append(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + ######################################### Encoder level2 ############################################ + ### ART encoder block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.ART_encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + out_enc_level2 = rearrange(out_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = self.Restormer_encoder_level2(out_enc_level2) + # stack.append(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.ART_encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + out_enc_level3 = rearrange(out_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = self.Restormer_encoder_level3(out_enc_level3) + # stack.append(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 
############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.ART_latent: + latent = layer(latent, [H // 8, W // 8]) + + ### Restormer encoder block + latent = rearrange(latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = self.Restormer_latent(latent) + # stack.append(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + out_dec_level3 = inp_dec_level3 + for layer in self.ART_decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + out_dec_level3 = rearrange(out_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + out_dec_level3 = self.Restormer_decoder_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + out_dec_level2 = inp_dec_level2 + for layer in self.ART_decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + out_dec_level2 = rearrange(out_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + out_dec_level2 = self.Restormer_decoder_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + out_dec_level1 = inp_dec_level1 + for layer in self.ART_decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + ### Restormer decoder block + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_decoder_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_refinement(out_dec_level1) + + out_dec_level1 = 
self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) + + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ART_Restormer_v2.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ART_Restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..de8f921e81d79cb7af14a75326f0ddaaf3b3f4c3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ART_Restormer_v2.py @@ -0,0 +1,663 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure, 每个block都是由ART和Restormer并联而成。 +输入特征分别经过ART和Restormer两个分支提取特征, 将两者的特征concat之后,用conv层变换得到该block的输出特征。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.enc_fuse_level1 = nn.Conv2d(int(dim * 
2 ** 1), dim, kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.enc_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.enc_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + self.enc_fuse_level4 = nn.Conv2d(int(dim * 2 ** 4), int(dim * 2 ** 3), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.dec_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], 
window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.dec_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.dec_fuse_level1 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + # self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + # bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + # self.fuse_refinement = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + # def forward(self, inp_img, aux_img): + # stack = [] + # bs, _, H, W = inp_img.shape + + # fuse_img = torch.cat((inp_img, aux_img), 1) + # inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + + def forward(self, inp_img): + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + + ######################################### Encoder level1 ############################################ + art_enc_level1 = inp_enc_level1 + restor_enc_level1 = inp_enc_level1 + + ### ART encoder block + for layer in self.ART_encoder_level1: + art_enc_level1 = layer(art_enc_level1, [H, W]) + + ### Restormer encoder block + restor_enc_level1 = rearrange(restor_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_enc_level1 = self.Restormer_encoder_level1(restor_enc_level1) ### (b, c, h, w) + + art_enc_level1 = rearrange(art_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = torch.cat((art_enc_level1, restor_enc_level1), 1) + out_enc_level1 = self.enc_fuse_level1(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + + ######################################### Encoder level2 ############################################ + ### ART encoder 
block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + art_enc_level2 = inp_enc_level2 + restor_enc_level2 = inp_enc_level2 + + for layer in self.ART_encoder_level2: + art_enc_level2 = layer(art_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + restor_enc_level2 = rearrange(restor_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + restor_enc_level2 = self.Restormer_encoder_level2(restor_enc_level2) + + art_enc_level2 = rearrange(art_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = torch.cat((art_enc_level2, restor_enc_level2), 1) + out_enc_level2 = self.enc_fuse_level2(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + art_enc_level3 = inp_enc_level3 + restor_enc_level3 = inp_enc_level3 + + for layer in self.ART_encoder_level3: + art_enc_level3 = layer(art_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + restor_enc_level3 = rearrange(restor_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + restor_enc_level3 = self.Restormer_encoder_level3(restor_enc_level3) + + art_enc_level3 = rearrange(art_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = torch.cat((art_enc_level3, restor_enc_level3), 1) + out_enc_level3 = self.enc_fuse_level3(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 ############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + art_latent = inp_enc_level4 + restor_latent = inp_enc_level4 + + for layer in self.ART_latent: + art_latent = layer(art_latent, [H // 8, W // 8]) + + ### Restormer encoder block + restor_latent = rearrange(restor_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + restor_latent = self.Restormer_latent(restor_latent) + + art_latent = rearrange(art_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = torch.cat((art_latent, restor_latent), 1) + latent = self.enc_fuse_level4(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + art_dec_level3 = inp_dec_level3 + restor_dec_level3 = inp_dec_level3 + + for layer in self.ART_decoder_level3: + art_dec_level3 = layer(art_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + restor_dec_level3 = rearrange(restor_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + restor_dec_level3 = 
self.Restormer_decoder_level3(restor_dec_level3) + + art_dec_level3 = rearrange(art_dec_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_dec_level3 = torch.cat((art_dec_level3, restor_dec_level3), 1) + out_dec_level3 = self.dec_fuse_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + art_dec_level2 = inp_dec_level2 + restor_dec_level2 = inp_dec_level2 + + for layer in self.ART_decoder_level2: + art_dec_level2 = layer(art_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + restor_dec_level2 = rearrange(restor_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + restor_dec_level2 = self.Restormer_decoder_level2(restor_dec_level2) + + art_dec_level2 = rearrange(art_dec_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_dec_level2 = torch.cat((art_dec_level2, restor_dec_level2), 1) + out_dec_level2 = self.dec_fuse_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + art_dec_level1 = inp_dec_level1 + restor_dec_level1 = inp_dec_level1 + + for layer in self.ART_decoder_level1: + art_dec_level1 = layer(art_dec_level1, [H, W]) + + ### Restormer decoder block + restor_dec_level1 = rearrange(restor_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_dec_level1 = self.Restormer_decoder_level1(restor_dec_level1) + + art_dec_level1 = rearrange(art_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = torch.cat((art_dec_level1, restor_dec_level1), 1) + out_dec_level1 = self.dec_fuse_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +# def build_model(args): +# return ART_Restormer( +# inp_channels=2, +# out_channels=1, +# dim=32, +# num_art_blocks=[2, 4, 4, 6], +# num_restormer_blocks=[2, 2, 2, 2], +# num_refinement_blocks=4, +# heads=[1, 2, 4, 8], +# window_size=[10, 10, 10, 10], mlp_ratio=4., +# qkv_bias=True, qk_scale=None, +# interval=[24, 12, 6, 3], +# ffn_expansion_factor = 2.66, +# LayerNorm_type = 'WithBias', ## Other option 'BiasFree' +# bias=False) + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 
2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ARTfuse_layer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ARTfuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..78af0b27528b50db897936f6b8b4eee51d566b1c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/ARTfuse_layer.py @@ -0,0 +1,705 @@ +""" +ART layer in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input to the Attention layer:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + # print("q range:", q.max(), q.min()) + # print("k range:", k.max(), k.min()) + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + # print("attention matrix:", attn.shape, attn) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + # print("feature before proj-layer:", x.max(), x.min()) + x = self.proj(x) + # print("feature after proj-layer:", x.max(), x.min()) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. 
+ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pre_norm=True): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + # self.conv = nn.Conv2d(dim, dim, kernel_size=1) + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + # self.conv = ConvBlock(self.dim, self.dim, drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.bn_norm = nn.BatchNorm2d(dim) + self.pre_norm = pre_norm + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + # x = self.conv(x) + shortcut = x + if self.pre_norm: + # print("using pre_norm") + x = self.norm1(x) ## 归一化之后特征范围变得非常小 + x = x.view(B, H, W, C) + # print("normalized input range:", x.max(), x.min(), x.mean()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 
1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + # print("range of feature before layer:", x.max(), x.min(), x.mean()) + # x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + # x_bn = self.bn_norm(x.permute(0, 3, 1, 2)) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + # print("[ART layer] input and output feature difference:", torch.mean(torch.abs(shortcut - x))) + # print("range of feature before ART layer:", shortcut.max(), shortcut.min(), shortcut.mean()) + # print("range of feature after ART layer:", x.max(), x.min(), x.mean()) + if self.pre_norm: + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + # print("using post_norm") + x = self.norm1(shortcut + self.drop_path(x)) + x = self.norm2(x + self.drop_path(self.mlp(x))) + # x = shortcut + self.drop_path(x) + # x = x + self.drop_path(self.mlp(x)) + + # print("range of fused feature:", x.max(), x.min()) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + + + +########################################################################## +class Cross_TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
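+        Both are token sequences of shape (B, H*W, C); each modality is window/interval-partitioned separately and then fused inside the shared attention below.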
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +class Cross_TransformerBlock_v2(nn.Module): + r""" ART Transformer Block. + 将两个模态的特征沿着channel方向concat, 一起取window. 之后把取出来的两个模态中同一个window的所有patches合并,送到transformer处理。 + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
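+        In this variant the two modalities are first concatenated along the channel dimension so that they share exactly the same window partition, then split and fused along the token dimension for the joint attention.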
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + xy = torch.cat((x, y), -1) + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + xy = F.pad(xy, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + xy = xy.reshape(B, Hd // G, G, Wd // G, G, 2*C).permute(0, 1, 3, 2, 4, 5).contiguous() + xy = xy.reshape(B * Hd * Wd // G ** 2, G ** 2, 2*C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + xy = xy.reshape(B, Gh, I, Gw, I, 2*C).permute(0, 2, 4, 1, 3, 5).contiguous() + xy = xy.reshape(B * I * I, Gh * Gw, 2*C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ x = xy[:, :, :C] + y = xy[:, :, C:] + xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DataConsistency.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DataConsistency.py new file mode 100644 index 0000000000000000000000000000000000000000..5f460e9acabcb33e1a3f812eb9acaeb337da446c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DataConsistency.py @@ -0,0 +1,48 @@ +""" +Created: DataConsistency @ Xiyang Cai, 2023/09/09 + +Data consistency layer for k-space signal. 
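+Replaces the network estimate at sampled k-space locations with the originally measured values: out = (1 - mask) * k + mask * k0.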
+ +Ref: DataConsistency in DuDoRNet (https://github.com/bbbbbbzhou/DuDoRNet) + +""" + +import torch +from torch import nn +from einops import repeat + + +def data_consistency(k, k0, mask): + """ + k - input in k-space + k0 - initially sampled elements in k-space + mask - corresponding nonzero location + """ + out = (1 - mask) * k + mask * k0 + return out + + +class DataConsistency(nn.Module): + """ + Create data consistency operator + """ + + def __init__(self): + super(DataConsistency, self).__init__() + + def forward(self, k, k0, mask): + """ + k - input in frequency domain, of shape (n, nx, ny, 2) + k0 - initially sampled elements in k-space + mask - corresponding nonzero location (n, 1, len, 1) + """ + + if k.dim() != 4: # input is 2D + raise ValueError("error in data consistency layer!") + + # mask = repeat(mask.squeeze(1, 3), 'b x -> b x y c', y=k.shape[1], c=2) + mask = torch.tile(mask, (1, mask.shape[2], 1, k.shape[-1])) ### [n, 320, 320, 2] + # print("k and k0 shape:", k.shape, k0.shape) + out = data_consistency(k, k0, mask) + + return out, mask diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DuDo_mUnet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DuDo_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a56069a27f0cdd392482b41dc4e3b79252295487 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DuDo_mUnet.py @@ -0,0 +1,359 @@ +""" +Xiaohan Xing, 2023/10/25 +传入两个模态的kspace data, 经过kspace network进行重建。 +然后把重建之后的kspace data变换回图像,然后送到image domain的网络进行图像重建。 +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + +from .Kspace_mUnet import kspace_mmUnet + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = kspace_ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = kspace_ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = kspace_ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = kspace_ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. 
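+        # The 1x1-conv + Sigmoid branch predicts a single-channel map in (0, 1) that re-weights the concatenated multi-scale k-space features.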
+ spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class DuDo_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 3, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + # self.kspace_net = mmConvKSpace() + self.kspace_net = kspace_mmUnet(args) + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
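+
+        Note (added): with args.domain == "dual", the k-space U-Net first interpolates the
+        undersampled k-space (with a data-consistency layer); the magnitude image obtained via
+        ifft2c is then stacked with the zero-filled (LR) image and the auxiliary modality, and
+        the image-domain U-Net refines the result. The method actually returns the tuple
+        (output, image, recon_kspace, masked_kspace, data_stats).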
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + stack = [] + feature_stack = [] + # output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ output = torch.cat((image, LR_image, aux_image), 1) #.detach() + + # print("LR image range:", LR_image.max(), LR_image.min()) + # print("kspace_recon image range:", image.max(), image.min()) + # print("aux_image range:", aux_image.max(), aux_image.min()) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output, image, recon_kspace, masked_kspace, data_stats + + + +class kspace_ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return DuDo_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DuDo_mUnet_ARTfusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DuDo_mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6df67b63140ca51bd7b8664151ab4b1b7a6305d5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DuDo_mUnet_ARTfusion.py @@ -0,0 +1,369 @@ +""" +Xiaohan Xing, 2023/11/08 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + ARTfusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on 
the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. 
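+        ### (added note) Mean/std are computed per sample over all pixels (view(B, -1), dim=1); the
+        ### tiny eps=1e-11 guards against division by zero, and the statistics are stored in
+        ### data_stats so outputs can be mapped back to the original intensity range. After
+        ### z-scoring, the inputs are clamped to [-6, 6] to suppress extreme outliers.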
+ t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + output1, output2 = aux_image.detach(), LR_image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
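+    (In English: resize the features of each layer to a 5x5 spatial grid, then compute the
+    25x25 relation matrix of cosine similarities between all pairs of positions.)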
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DuDo_mUnet_CatFusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DuDo_mUnet_CatFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c2b86b6e2031b1af07ab7cb58efc182b7a2673 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/DuDo_mUnet_CatFusion.py @@ -0,0 +1,338 @@ +""" +Xiaohan Xing, 2023/11/10 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. 
+image domain用mUnet + multi layer concat fusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_CATfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t2_gt: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + t2_image = t2_gt * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + # # print("aux_image range:", aux_image.max(), aux_image.min()) + # print("LR_image range:", LR_image.max(), LR_image.min()) + # print("kspace recon_img range:", image.max(), image.min()) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. 
+ t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + t2_gt_image = self.normalize(t2_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + t2_gt_image = torch.clamp(t2_gt_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + # output1, output2 = aux_image.detach(), LR_image.detach() + output1, output2 = aux_image.detach(), image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image, recon_kspace, masked_kspace, data_stats + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_CATfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Kspace_ConvNet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Kspace_ConvNet.py new file mode 100644 index 0000000000000000000000000000000000000000..918cbe598a62653394c8c4eaf99cd10a45c064ca --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Kspace_ConvNet.py @@ -0,0 +1,140 @@ +""" +2023/10/05 +Xiaohan Xing +Build a simple network with several convolution layers. +Input: concat (undersampled kspace of FS-PD, fully-sampled kspace of PD modality) +Output: interpolated kspace of FS-PD +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. 
+ """ + def __init__( + self, args, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + # self.conv_blocks = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # self.conv_blocks.append(ConvBlock(self.chans, self.chans * 2, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans * 2, self.chans, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans, self.out_chans, drop_prob)) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + # print("output of the three conv blocks:", output1.shape, output2.shape, output3.shape) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("conv1_output range:", output1.max(), output1.min()) + # print("conv2_output range:", output2.max(), output2.min()) + # print("conv3_output range:", output3.max(), output3.min()) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. + spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + # print("output before DC layer:", output.shape) + # print("output range before DC layer:", output.max(), output.min()) + # output = output + kspace ## residual connection + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + # mask_output = output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +def build_model(args): + return mmConvKSpace(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Kspace_mUnet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Kspace_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2db9a357798be4abd4cfbd44e9207555770b0456 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Kspace_mUnet.py @@ -0,0 +1,202 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # print("output:", output.shape) + # print("kspace:", kspace.shape) + # print("mask:", mask.shape) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. 
+ + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Kspace_mUnet_new.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Kspace_mUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ccfd28dc198ffeff2259a5fbed45dabdc76819 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Kspace_mUnet_new.py @@ -0,0 +1,200 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. 
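+        ### (added note) The 4 input channels are the real/imaginary channels of the target k-space
+        ### followed by those of the reference modality, as described in the module docstring.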
+ # print("image shape:", image.shape) + + # print("using Unet without Relu layers") + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + # output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # # print("output:", output.shape) + # # print("kspace:", kspace.shape) + # # print("mask:", mask.shape) + # mask_output = mask * output + + return output # .permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/MINet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeeb063b12be1dc594e8e076e6a59e454fcfa0b --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/MINet.py @@ -0,0 +1,356 @@ +""" +Compare with the method "Multi-contrast MRI Super-Resolution via a Multi-stage Integration Network". +The code is from https://github.com/chunmeifeng/MINet/edit/main/fastmri/models/MINet.py. +""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F + + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, 
x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 1 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + # print("modules_tail:", modules_tail) + # self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Sigmoid()) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + + + +class MINet(nn.Module): + def __init__(self, args, n_resgroups=3, n_resblocks=3, n_feats=64): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + # print("x1:", x1.shape, "x2:", x2.shape) + + ### 源代码中是super-resolution任务,所以self.tail对图像进行unsampling. 
我们做reconstruction把这步去掉 + x2 = self.tail(x2) + + + resT1 = x1 + resT2 = x2 + + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + # print("t1s:", [item.shape for item in t1s]) + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + # print("resT1 range:", resT1.max(), resT1.min()) + # print("resT2 range:", resT2.max(), resT2.min()) + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1, x2 #x1=pd x2=pdfs + + +def build_model(args): + return MINet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/MINet_common.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/MINet_common.py new file mode 100644 index 0000000000000000000000000000000000000000..631f3036db4edf8eab00d70c66d2058a3b65cd59 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/MINet_common.py @@ -0,0 +1,90 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
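+            # `scale & (scale - 1) == 0` is true only for powers of two (e.g. 4 & 3 == 0b100 & 0b011 == 0),
+            # so the loop below stacks log2(scale) stages, each widening n_feats -> 4*n_feats with a 3x3
+            # conv and then doubling the spatial size via PixelShuffle(2); for scale=1 (the reconstruction
+            # setting used by SR_Branch) the loop runs zero times and the Upsampler is effectively a no-op.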
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + # print("upsampler:", m) + + super(Upsampler, self).__init__(*m) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/SANet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/SANet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfac3590e248c9532b002885c6c0b7a55ab1400a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/SANet.py @@ -0,0 +1,392 @@ +""" +This script contains the codes of "Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution (TNNLS 2023)" +https://github.com/chunmeifeng/SANet/blob/main/fastmri/models/SANet.py +""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import os + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = 
self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,scale,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups #10 + self.n_resblocks = n_resblocks #20 + self.n_feats = n_feats #64 + kernel_size = 3 + reduction = 16 #16 + # scale = args.scale[0] + + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: 
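+                    # Size mismatches outside the 'tail' (upsampler) are treated as real
+                    # incompatibilities and reported with the parameter name and both shapes.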
+ raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' + .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class Seatt(nn.Module): + def __init__(self, in_c): + super(Seatt, self).__init__() + self.reduce = nn.Conv2d(in_c * 2, 32, 1) + self.ff_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.bf_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.rgbd_pred_layer = Pred_Layer(32 * 2) + self.convq = nn.Conv2d(32, 32, 3, 1, 1) + + def forward(self, rgb_feat, dep_feat, pred): + feat = torch.cat((rgb_feat, dep_feat), 1) + + # vis_attention_map_rgb_feat(rgb_feat[0][0]) + # vis_attention_map_dep_feat(dep_feat[0][0]) + # vis_attention_map_feat(feat[0][0]) + + + feat = self.reduce(feat) + [_, _, H, W] = feat.size() + pred = torch.sigmoid(pred) + + ni_pred = 1 - pred + + ff_feat = self.ff_conv(feat * pred) + bf_feat = self.bf_conv(feat * (1 - pred)) + new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) + + return new_pred + +class SANet(nn.Module): + def __init__(self, args, scale=1, n_resgroups=3, n_resblocks=3, n_feats=64): + super(SANet, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + + cs = [64,64,64,64] + self.Seatts = nn.ModuleList([Seatt(c) for c in cs]) + + self.net1 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + print("nlayer:",nlayer) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=3, padding=1) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2, Seatt, map_conv in zip(self.net1.body._modules.items(), self.net2.body._modules.items(), self.Seatts, self.map_convs): + + name1, midlayer1 = m1 + _, midlayer2 = m2 + 
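+            # Per residual group: run each contrast through its own branch, predict a 1-channel
+            # attention map from the T1 features, fuse the two feature maps with Seatt under that
+            # map, and add the fused result back into the T2 branch (T1 guides T2).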
+ resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + pred = map_conv(resT1) + res = Seatt(resT1,resT2,pred) + + resT2 = res+resT2 + + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + + return x1,x2 + + +def build_model(args): + return SANet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/SwinFuse_layer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/SwinFuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3dca8ba793d92961bc3f1d987e211fe542b303 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/SwinFuse_layer.py @@ -0,0 +1,1310 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # print("x shape:", x.shape) + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + # print("num heads:", self.num_heads) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + 
flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, 
N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + # print("x_windows:", x_windows.shape) + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
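+        Note: forward(x, x_size) applies the blocks sequentially (wrapped in
+            torch.utils.checkpoint.checkpoint when use_checkpoint is True) and finishes
+            with the optional downsample.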
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + + # # 不使用残差连接 + # x = self.residual_group_A(x, x_size) + # y = self.residual_group_B(y, x_size) + + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion_layer(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
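+    This trimmed variant keeps only the cross-modal fusion stage of SwinFusion: both inputs are
+    padded to a multiple of window_size, embedded into patch tokens, passed through the stacked
+    CRSTB fusion layers, un-embedded, and cropped back to the original spatial size, so the two
+    returned feature maps match the inputs in shape.
+
+    Example (a minimal sketch; the settings mirror the test block at the bottom of this file,
+    and the channel count of the inputs must equal embed_dim)::
+
+        layer = SwinFusion_layer(img_size=(120, 120), in_chans=1, embed_dim=60, window_size=8)
+        A = torch.randn(4, 60, 120, 120)   # modality-1 feature map
+        B = torch.randn(4, 60, 120, 120)   # modality-2 feature map
+        fused_A, fused_B = layer(A, B)     # both outputs are (4, 60, 120, 120)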
+ Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Fusion_num_heads=[6, 6], + window_size=7, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, img_range=1., resi_connection='1conv', + **kwargs): + super(SwinFusion_layer, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + 
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + + ### fusion layer. + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + self.norm_Fusion_B = norm_layer(self.num_features) + + # print("fusion layers:", len(self.layers_Fusion)) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + + def forward_features_Fusion(self, x, y): + + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + # x = torch.cat([x, y], 1) + # ## Downsample the feature in the channel dimension + # x = self.lrelu(self.conv_after_body_Fusion(x)) + # print("x range after fusion:", x.max(), x.min()) + # print("y 
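# A minimal standalone sketch of the stochastic-depth decay rule used above: drop-path
# rates grow linearly from 0 to drop_path_rate across all blocks, and each CRSTB gets the
# slice that belongs to its own blocks (values below are illustrative only).
import torch

drop_path_rate = 0.1
Fusion_depths = [2, 2]  # two CRSTBs with two cross blocks each, as in the defaults above
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))]
for i_layer, depth in enumerate(Fusion_depths):
    start = sum(Fusion_depths[:i_layer])
    print(f"CRSTB {i_layer}: drop_path = {dpr[start:start + depth]}")
# CRSTB 0: drop_path = [0.0, 0.033...]
# CRSTB 1: drop_path = [0.066..., 0.1]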
range after fusion:", y.max(), y.min()) + # print("x:", x.shape) + + return x, y + + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = self.check_image_size(y) + + print("modality1 input:", x.max(), x.min()) + print("modality2 input:", y.max(), y.min()) + + # multi-modal feature fusion. + x, y = self.forward_features_Fusion(x, y) ### 返回两个模态融合之后的features. + print("modality1 output:", x.max(), x.min()) + print("modality2 output:", y.max(), y.min()) + + return x[:, :, :H, :W], y[:, :, :H, :W] + + + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + + +if __name__ == '__main__': + width, height = 120, 120 + swinfuse = SwinFusion_layer(img_size=(width, height), window_size=8, depths=[6, 6, 6], embed_dim=60, num_heads=[6, 6, 6]) + # print("SwinFusion layer:", swinfuse) + + A = torch.randn((4, 60, width, height)) + B = torch.randn((4, 60, width, height)) + + x = swinfuse(A, B) + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/SwinFusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/SwinFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3e72c2a3e8859423f3544043e8b9feaec002d0e2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/SwinFusion.py @@ -0,0 +1,1468 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # print("qkv after reshape:", qkv.shape) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. 
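# A stripped-down sketch of the cross-attention this class implements (its forward is a few
# lines below): queries come from one modality's window tokens, keys/values from the other's.
# Relative position bias, masking and dropout are omitted; shapes and sizes are illustrative.
import torch
import torch.nn as nn

B_, N, C, num_heads = 8, 49, 96, 6           # windows*batch, tokens per 7x7 window, channels, heads
head_dim = C // num_heads
q_proj = nn.Linear(C, C)
kv_proj = nn.Linear(C, 2 * C)

x = torch.randn(B_, N, C)                    # modality A window tokens (queries)
y = torch.randn(B_, N, C)                    # modality B window tokens (keys / values)

q = q_proj(x).reshape(B_, N, num_heads, head_dim).permute(0, 2, 1, 3)          # B_, nH, N, hd
kv = kv_proj(y).reshape(B_, N, 2, num_heads, head_dim).permute(2, 0, 3, 1, 4)  # 2, B_, nH, N, hd
k, v = kv[0], kv[1]

attn = (q * head_dim ** -0.5) @ k.transpose(-2, -1)  # B_, nH, N, N
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B_, N, C)   # back to B_, N, C
assert out.shape == (B_, N, C)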
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
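# Standalone sketch of how calculate_mask builds the SW-MSA attention mask: after the cyclic
# shift, tokens from different image regions can share a window, and such pairs receive -100
# so the softmax ignores them. Toy sizes below; the real code uses the block's window_size.
import torch

H = W = 8
window_size, shift_size = 4, 2
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt           # label each of the 3x3 shifted regions
        cnt += 1

# partition the label map into windows, exactly like window_partition() does
m = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
mask_windows = m.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)                       # torch.Size([4, 16, 16]): one mask per window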
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + # 不使用残差连接 + # print("input x shape:", x.shape) + x = self.residual_group_A(x, x_size) + y = self.residual_group_B(y, x_size) + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
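# Small sketch of the pixel-shuffle upsampling scheme used by the Upsample module above: a
# x4 factor is realised as two (Conv 3x3 -> PixelShuffle(2)) stages, each turning 4*C
# channels into C channels at twice the spatial resolution. Sizes below are illustrative.
import math
import torch
import torch.nn as nn

num_feat, scale = 64, 4
stages = []
for _ in range(int(math.log(scale, 2))):     # this path assumes scale is a power of two
    stages += [nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)]
up = nn.Sequential(*stages)
x = torch.randn(1, num_feat, 30, 30)
print(up(x).shape)                           # torch.Size([1, 64, 120, 120])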
+ + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Re_depths=[4], + Ex_num_heads=[6], Fusion_num_heads=[6, 6], Re_num_heads=[6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinFusion, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + print('in_chans: ', in_chans) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + ####修改shallow feature extraction 网络, 修改为2个3x3的卷积#### + self.conv_first1_A = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first1_B = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first2_A = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1) + self.conv_first2_B = nn.Conv2d(embed_dim_temp, embed_dim_temp, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + self.Re_num_layers = len(Re_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = 
patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + dpr_Re = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Re_depths))] # stochastic depth decay rule + # build Residual Swin Transformer blocks (RSTB) + self.layers_Ex_A = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_A.append(layer) + self.norm_Ex_A = norm_layer(self.num_features) + + self.layers_Ex_B = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_B.append(layer) + self.norm_Ex_B = norm_layer(self.num_features) + + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + 
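# With patch_size=1 (the default here), PatchEmbed / PatchUnEmbed defined earlier in this
# file reduce to a flatten / reshape between image layout and token layout; a minimal
# standalone equivalent of that round trip:
import torch

B, C, H, W = 2, 96, 56, 56
feat = torch.randn(B, C, H, W)
tokens = feat.flatten(2).transpose(1, 2)                # PatchEmbed:   (B, C, H, W) -> (B, H*W, C)
restored = tokens.transpose(1, 2).reshape(B, C, H, W)   # PatchUnEmbed: (B, H*W, C) -> (B, C, H, W)
assert torch.equal(feat, restored)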
self.norm_Fusion_B = norm_layer(self.num_features) + + self.layers_Re = nn.ModuleList() + for i_layer in range(self.Re_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Re_depths[i_layer], + num_heads=Re_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Re[sum(Re_depths[:i_layer]):sum(Re_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Re.append(layer) + self.norm_Re = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
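# Sketch of the check_image_size contract used by forward() further down: inputs are
# reflect-padded up to a multiple of window_size so they split into whole windows, and the
# output is cropped back to the original H x W at the end. Sizes below are illustrative.
import torch
import torch.nn.functional as F

window_size = 7
x = torch.randn(1, 1, 240, 240)
_, _, h, w = x.shape
pad_h = (window_size - h % window_size) % window_size
pad_w = (window_size - w % window_size) % window_size
x_pad = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')       # 240 -> 245 on both spatial dims here
print(x_pad.shape)                                      # torch.Size([1, 1, 245, 245])
out = x_pad[:, :, :h, :w]                               # crop back to the original size
print(out.shape)                                        # torch.Size([1, 1, 240, 240])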
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last1 = nn.Conv2d(embed_dim, embed_dim_temp, 3, 1, 1) + self.conv_last2 = nn.Conv2d(embed_dim_temp, int(embed_dim_temp/2), 3, 1, 1) + self.conv_last3 = nn.Conv2d(int(embed_dim_temp/2), num_out_ch, 3, 1, 1) + + self.tanh = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features_Ex_A(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_A: + x = layer(x, x_size) + + x = self.norm_Ex_A(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_Ex_B(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_B: + x = layer(x, x_size) + + x = self.norm_Ex_B(x) # B L C + x = self.patch_unembed(x, x_size) + return x + + def forward_features_Fusion(self, x, y): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + # print("x token:", x.shape, "y token:", y.shape) + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + x = torch.cat([x, y], 1) + ## Downsample the feature in the channel dimension + x = self.lrelu(self.conv_after_body_Fusion(x)) + + return x + + def forward_features_Re(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Re: + x = layer(x, x_size) + + x = self.norm_Re(x) # B L C + x = self.patch_unembed(x, x_size) + # Convolution + x = self.lrelu(self.conv_last1(x)) + x = self.lrelu(self.conv_last2(x)) + x = self.conv_last3(x) + return x + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = 
self.check_image_size(y) + + # self.mean_A = self.mean.type_as(x) + # self.mean_B = self.mean.type_as(y) + # self.mean = (self.mean_A + self.mean_B) / 2 + + # x = (x - self.mean_A) * self.img_range + # y = (y - self.mean_B) * self.img_range + + # Feedforward + x = self.forward_features_Ex_A(x) + y = self.forward_features_Ex_B(y) + # print("x before fusion:", x.shape) + # print("y before fusion:", y.shape) + x = self.forward_features_Fusion(x, y) + x = self.forward_features_Re(x) + x = self.tanh(x) + + # x = x / self.img_range + self.mean + # print("H:", H, "upscale:", self.upscale) + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='convs') + + +if __name__ == '__main__': + model = SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='convs') + # print(model) + + A = torch.randn((1, 1, 240, 240)) + B = torch.randn((1, 1, 240, 240)) + + x = model(A, B) + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/TransFuse.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/TransFuse.py new file mode 100644 index 0000000000000000000000000000000000000000..56ca3d2b24f382603c876e4f6909363a9794ffa3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/TransFuse.py @@ -0,0 +1,155 @@ +import math +from collections import deque + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torchvision import models + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. 
+ """ + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(n_embd, n_embd) + self.query = nn.Linear(n_embd, n_embd) + self.value = nn.Linear(n_embd, n_embd) + # regularization + self.attn_drop = nn.Dropout(attn_pdrop) + self.resid_drop = nn.Dropout(resid_pdrop) + # output projection + self.proj = nn.Linear(n_embd, n_embd) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop): + super().__init__() + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.mlp = nn.Sequential( + nn.Linear(n_embd, block_exp * n_embd), + nn.ReLU(True), # changed from GELU + nn.Linear(block_exp * n_embd, n_embd), + nn.Dropout(resid_pdrop), + ) + + def forward(self, x): + B, T, C = x.size() + + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + + return x + + +class TransFuse_layer(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, n_embd, n_head, block_exp, n_layer, + num_anchors, seq_len=1, + embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1): + super().__init__() + self.n_embd = n_embd + self.seq_len = seq_len + self.vert_anchors = num_anchors + self.horz_anchors = num_anchors + + # positional embedding parameter (learnable), image + lidar + self.pos_emb = nn.Parameter(torch.zeros(1, 2 * seq_len * self.vert_anchors * self.horz_anchors, n_embd)) + + self.drop = nn.Dropout(embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(n_embd, n_head, + block_exp, attn_pdrop, resid_pdrop) + for layer in range(n_layer)]) + + # decoder head + self.ln_f = nn.LayerNorm(n_embd) + + self.block_size = seq_len + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + + def forward(self, m1, m2): + """ + Args: + m1 (tensor): B*seq_len, C, H, W + m2 (tensor): B*seq_len, C, H, W + """ + + bz = m2.shape[0] // self.seq_len + h, w = m2.shape[2:4] + + # forward the image model for token embeddings + m1 = m1.view(bz, self.seq_len, -1, h, w) + m2 = m2.view(bz, self.seq_len, -1, h, w) + + # pad token embeddings along number of tokens dimension + token_embeddings = torch.cat([m1, m2], 
dim=1).permute(0,1,3,4,2).contiguous() + token_embeddings = token_embeddings.view(bz, -1, self.n_embd) # (B, an * T, C) + + # add (learnable) positional embedding for all tokens + x = self.drop(self.pos_emb + token_embeddings) # (B, an * T, C) + x = self.blocks(x) # (B, an * T, C) + x = self.ln_f(x) # (B, an * T, C) + x = x.view(bz, 2 * self.seq_len, self.vert_anchors, self.horz_anchors, self.n_embd) + x = x.permute(0,1,4,2,3).contiguous() # same as token_embeddings + + m1_out = x[:, :self.seq_len, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + m2_out = x[:, self.seq_len:, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + print("modality1 output:", m1_out.max(), m1_out.min()) + print("modality2 output:", m2_out.max(), m2_out.min()) + + return m1_out, m2_out + + +if __name__ == "__main__": + + feature1 = torch.randn((4, 512, 20, 20)) + feature2 = torch.randn((4, 512, 20, 20)) + model = TransFuse_layer(n_embd=512, n_head=4, block_exp=4, n_layer=8, num_anchors=20, seq_len=1) + print("TransFuse_layer:", model) + feat1, feat2 = model(feature1, feature2) + print(feat1.shape, feat2.shape) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Unet_ART.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Unet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b0f834270a921ae4eb428c92b3e782fbed19b3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/Unet_ART.py @@ -0,0 +1,243 @@ +""" +2023/07/06, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + output = conv_layer(output) + + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/__init__.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..334f1821a9b59e692641f70cdc61e7655b1a9c89 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/__init__.py @@ -0,0 +1,114 @@ +from .unet import build_model as UNET +from .mmunet import build_model as MUNET +from .humus_net import build_model as HUMUS +from .MINet import build_model as MINET +from .SANet import build_model as SANET +from .mtrans_net import build_model as MTRANS +from .unimodal_transformer import build_model as TRANS +from .cnn_transformer_new import build_model as CNN_TRANS +from .cnn import build_model as CNN +from .unet_transformer import build_model as UNET_TRANSFORMER +from .swin_transformer import build_model as SWIN_TRANS +from .restormer import build_model as RESTORMER + + +from .cnn_late_fusion import build_model as MULTI_CNN +from .mmunet_late_fusion import build_model as MMUNET_LATE +from .mmunet_early_fusion import build_model as MMUNET_EARLY +from .trans_unet.trans_unet import build_model as TRANSUNET +from .SwinFusion import build_model as SWINMULTI +from .mUnet_transformer import build_model as MUNET_TRANSFORMER +from .munet_multi_transfuse import build_model as MUNET_TRANSFUSE +from .munet_swinfusion import build_model as MUNET_SWINFUSE +from .munet_multi_concat import build_model as MUNET_CONCAT +from .munet_concat_decomp import build_model as MUNET_CONCAT_DECOMP +from .munet_multi_sum import build_model as MUNET_SUM +# from .mUnet_ARTfusion_new import build_model as MUNET_ART_FUSION +from .mUnet_ARTfusion import build_model as MUNET_ART_FUSION +from .mUnet_Restormer_fusion import build_model as MUNET_RESTORMER_FUSION +from .mUnet_ARTfusion_SeqConcat import build_model as MUNET_ART_FUSION_SEQ + +from .ARTUnet import build_model as ART +from .Unet_ART import build_model as UNET_ART +from .unet_restormer import build_model as UNET_RESTORMER +from .mmunet_ART import build_model as MMUNET_ART +from .mmunet_restormer import build_model as MMUNET_RESTORMER +from .mmunet_restormer_v2 import build_model as MMUNET_RESTORMER_V2 +from .mmunet_restormer_3blocks import build_model as MMUNET_RESTORMER_SMALL + +from .mARTUnet import build_model as ART_MULTI_INPUT + +from .ART_Restormer import build_model as ART_RESTORMER +from .ART_Restormer_v2 import build_model as ART_RESTORMER_V2 +from .swinIR import build_model as SWINIR + +from .Kspace_ConvNet import build_model as KSPACE_CONVNET +from .Kspace_mUnet_new import build_model as KSPACE_MUNET +from .kspace_mUnet_concat import build_model as KSPACE_MUNET_CONCAT +from .kspace_mUnet_AttnFusion import build_model as KSPACE_MUNET_ATTNFUSE + +from .DuDo_mUnet import build_model as DUDO_MUNET +from .DuDo_mUnet_ARTfusion import build_model as DUDO_MUNET_ARTFUSION +from .DuDo_mUnet_CatFusion import build_model as 
DUDO_MUNET_CONCAT + + +model_factory = { + 'unet_single': UNET, + 'humus_single': HUMUS, + 'transformer_single': TRANS, + 'cnn_transformer': CNN_TRANS, + 'cnn_single': CNN, + 'swin_trans_single': SWIN_TRANS, + 'trans_unet': TRANSUNET, + 'unet_transformer': UNET_TRANSFORMER, + 'restormer': RESTORMER, + 'unet_art': UNET_ART, + 'unet_restormer': UNET_RESTORMER, + 'art': ART, + 'art_restormer': ART_RESTORMER, + 'art_restormer_v2': ART_RESTORMER_V2, + 'swinIR': SWINIR, + + + 'munet_transformer': MUNET_TRANSFORMER, + 'munet_transfuse': MUNET_TRANSFUSE, + 'cnn_late_multi': MULTI_CNN, + 'unet_multi': MUNET, + 'unet_late_multi':MMUNET_LATE, + 'unet_early_multi':MMUNET_EARLY, + 'munet_ARTfusion': MUNET_ART_FUSION, + 'munet_restormer_fusion': MUNET_RESTORMER_FUSION, + + + 'minet_multi': MINET, + 'sanet_multi': SANET, + 'mtrans_multi': MTRANS, + 'swin_fusion': SWINMULTI, + 'munet_swinfuse': MUNET_SWINFUSE, + 'munet_concat': MUNET_CONCAT, + 'munet_concat_decomp': MUNET_CONCAT_DECOMP, + 'munet_sum': MUNET_SUM, + 'mmunet_art': MMUNET_ART, + 'mmunet_restormer': MMUNET_RESTORMER, + 'mmunet_restormer_small': MMUNET_RESTORMER_SMALL, + 'mmunet_restormer_v2': MMUNET_RESTORMER_V2, + + 'art_multi_input': ART_MULTI_INPUT, + + 'munet_ARTfusion_SeqConcat': MUNET_ART_FUSION_SEQ, + + 'kspace_ConvNet': KSPACE_CONVNET, + 'kspace_munet': KSPACE_MUNET, + 'kspace_munet_concat': KSPACE_MUNET_CONCAT, + 'kspace_munet_AttnFusion': KSPACE_MUNET_ATTNFUSE, + + 'dudo_munet': DUDO_MUNET, + 'dudo_munet_ARTfusion': DUDO_MUNET_ARTFUSION, + 'dudo_munet_concat': DUDO_MUNET_CONCAT, +} + + +def build_model_from_name(args): + assert args.model_name in model_factory.keys(), 'unknown model name' + + return model_factory[args.model_name](args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c130178e7cdabaaef8686da0cec341265f3c43d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn.py @@ -0,0 +1,246 @@ +""" +把our method中的单模态CNN_Transformer中的transformer换成几个conv_block用来下采样提取特征, 看是否能够达到和Unet相似的效果。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Reconstructor(nn.Module): + def __init__(self): + super(CNN_Reconstructor, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + nlatent = self.encoder.output_dim + self.decoder = Decoder(n_upsample=4, n_res=1, dim=nlatent, output_dim=1, pad_type='zero') + + + def forward(self, m): + #4,1,240,240 + feature_output = self.encoder.model(m) # [4, 256, 60, 60] + # print("output of the encoder:", feature_output.shape) + + output = self.decoder(feature_output) + + return output + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return 
self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks 
+################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + + +def build_model(args): + return CNN_Reconstructor() diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn_late_fusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..58dacad6910b6e896d00a0d36c9ea60e8d77217f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn_late_fusion.py @@ -0,0 +1,196 @@ +""" +在lequan的代码上,去掉transformer-based feature fusion部分。 +直接用CNN提取特征之后concat作为multi-modal representation. 用普通的decoder进行图像重建。 +把T1作为guidance modality, T2作为target modality. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. 
+ num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers1 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + self.down_sample_layers2 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers1.append(ConvBlock(ch, ch * 2, drop_prob)) + self.down_sample_layers2.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + + # self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # ch = chans + # for _ in range(num_pool_layers - 1): + # self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + # ch *= 2 + # self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv = nn.Sequential(nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), nn.Tanh()) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = torch.cat([image, aux_image], 1) + # for layer in self.down_sample_layers: + # output = layer(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + + # apply down-sampling layers + output = image + for layer in self.down_sample_layers1: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + output_m1 = output + + output = aux_image + for layer in self.down_sample_layers2: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + output_m2 = output + + # print("two modality outputs:", output_m1.shape, output.shape) + output = torch.cat([output_m1, output_m2], 1) + + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv in self.up_transpose_conv: + output = transpose_conv(output) + + output = self.up_conv(output) + # print("model output:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn_transformer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..261d4d214e35c3a9be5d28bc426ce7280124723c --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn_transformer.py @@ -0,0 +1,289 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 
送到transformer网络中进行MRI重建。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=2, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=2, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 60 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 4 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
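As a quick shape check for the convolutional patch embedding above (a standalone sketch, not part of the patch): with the values set in this file (encoder output of 256 channels at 60x60, patch_size=4, patch_dim=512*2), the Conv2d whose kernel and stride both equal patch_size produces a 15x15 grid, and the Rearrange flattens it into the 225 tokens noted in the comment.

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

patch_dim, patch_size = 1024, 4            # 512*2 and the patch size used above
to_patch_embedding = nn.Sequential(
    nn.Conv2d(256, patch_dim, kernel_size=patch_size, stride=patch_size),
    Rearrange('b e h w -> b (h w) e'),     # flatten the 15x15 grid into tokens
)
m_out = torch.randn(4, 256, 60, 60)        # encoder feature map assumed in the comments
print(to_patch_embedding(m_out).shape)     # torch.Size([4, 225, 1024])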
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + feature_output = self.upsampling1(feature_output) # [4,512,30,30] + feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn_transformer_new.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn_transformer_new.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b2dea1e7cc8d456a684b6c839f052383999e84 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/cnn_transformer_new.py @@ -0,0 +1,290 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 送到transformer网络中进行MRI重建。 +2023/05/23: 之前是从尺寸为60*60的特征图上切patch, 现在可以尝试直接用conv_blocks提取15*15的特征图,然后按照patch_size = 1切成225个patches。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=4, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 1 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
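In this patch_size=1 variant the inline shape comment appears to be carried over from the 60x60 version; with n_downsample=4 on a 240x240 input the encoder should emit a 1024-channel 15x15 map, so the 1x1 patch convolution turns every spatial position into one token and again yields 225 tokens. A small sketch of that assumed shape flow:

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

patch_dim = 1024                           # 512*2, matching encoder.output_dim here
to_patch_embedding = nn.Sequential(
    nn.Conv2d(1024, patch_dim, kernel_size=1, stride=1),
    Rearrange('b e h w -> b (h w) e'),     # one token per spatial position
)
m_out = torch.randn(4, 1024, 15, 15)       # assumed encoder output for 240x240 inputs
print(to_patch_embedding(m_out).shape)     # torch.Size([4, 225, 1024])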
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + # feature_output = self.upsampling1(feature_output) # [4,512,30,30] + # feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/humus_net.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/humus_net.py new file mode 100644 index 0000000000000000000000000000000000000000..d23701634f849b414866d71b3c23c6d07f3d6437 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/humus_net.py @@ -0,0 +1,1070 @@ +""" +HUMUS-Block +Hybrid Transformer-convolutional Multi-scale denoiser. + +We use parts of the code from +SwinIR (https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py) +Swin-Unet (https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.py) + +HUMUS-Net中是对kspace data操作, 多个unrolling cascade, 每个cascade中都利用HUMUS-block对图像进行denoising. +因为我们的方法输入是Under-sampled images, 所以不需要进行unrolling, 直接使用一个HUMUS-block进行MRI重建。 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # print("x shape:", x.shape) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. 
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + 
# reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + +class PatchExpandSkip(nn.Module): + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) + self.channel_mix = nn.Linear(dim, dim // 2, bias=False) + self.norm = norm_layer(dim // 2) + + def forward(self, x, skip): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x = torch.cat([x, skip], dim=-1) + x = self.channel_mix(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, torch.tensor(x_size)) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ block_type: + D: downsampling block, + U: upsampling block + B: bottleneck block + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', block_type='B'): + super(RSTB, self).__init__() + self.dim = dim + conv_dim = dim // (patch_size ** 2) + divide_out_ch = 1 + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(conv_dim, conv_dim // divide_out_ch, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(conv_dim, conv_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // divide_out_ch, 3, 1, 1)) + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.block_type = block_type + if block_type == 'B': + self.reshape = torch.nn.Identity() + elif block_type == 'D': + self.reshape = PatchMerging(input_resolution, dim) + elif block_type == 'U': + self.reshape = PatchExpandSkip([res // 2 for res in input_resolution], dim * 2) + else: + raise ValueError('Unknown RSTB block type.') + + def forward(self, x, x_size, skip=None): + if self.block_type == 'U': + assert skip is not None, "Skip connection is required for patch expand" + x = self.reshape(x, skip) + + out = self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) + block_out = self.patch_embed(out) + x + + if self.block_type == 'D': + block_out = (self.reshape(block_out), block_out) # return skip connection + + return block_out + +class PatchEmbed(nn.Module): + r""" Fixed Image to Patch Embedding. Flattens image along spatial dimensions + without learned projection mapping. Only supports 1x1 patches. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + embed_dim (int): Number of output channels. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + +class PatchEmbedLearned(nn.Module): + r""" Image to Patch Embedding with arbitrary patch size and learned linear projection. 
+ Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_chans * patch_size[0] * patch_size[1], embed_dim) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.patch_to_vec(x) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + def patch_to_vec(self, x): + """ + Args: + x: (B, C, H, W) + Returns: + patches: (B, num_patches, patch_size * patch_size * C) + """ + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(B, H // self.patch_size[0], self.patch_size[0], W // self.patch_size[1], self.patch_size[1], C) + patches = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.patch_size[0], self.patch_size[1], C) + patches = patches.contiguous().view(B, self.num_patches, self.patch_size[0] * self.patch_size[1] * C) + return patches + + +class PatchUnEmbed(nn.Module): + r""" Fixed Image to Patch Unembedding via reshaping. Only supports patch size 1x1. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + +class PatchUnEmbedLearned(nn.Module): + r""" Patch Embedding to Image with learned linear projection. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. 
+ """ + + def __init__(self, img_size, patch_size, in_chans=None, out_chans=None, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.out_chans = out_chans + + self.proj_up = nn.Linear(in_chans, in_chans * patch_size[0] * patch_size[1]) + + def forward(self, x, x_size=None): + B, HW, C = x.shape + x = self.proj_up(x) + x = self.vec_to_patch(x) + return x + + def vec_to_patch(self, x): + B, HW, C = x.shape + x = x.view(B, self.patches_resolution[0], self.patches_resolution[1], C).contiguous() + x = x.view(B, self.patches_resolution[0], self.patch_size[0], self.patches_resolution[1], self.patch_size[1], C // (self.patch_size[0] * self.patch_size[1])).contiguous() + x = x.view(B, self.img_size[0], self.img_size[1], -1).contiguous().permute(0, 3, 1, 2) + return x + + +class HUMUSNet(nn.Module): + r""" HUMUS-Block + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depths (tuple(int)): Depth of each Swin Transformer layer in encoder and decoder paths. + num_heads (tuple(int)): Number of attention heads in different layers of encoder and decoder. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + use_checkpoint (bool): Whether to use checkpointing to save memory. + img_range: Image range. 1. or 255. + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + conv_downsample_first: use convolutional downsampling before MUST to reduce compute load on Transformers + """ + + def __init__(self, args, + img_size=[240, 240], + in_chans=1, + patch_size=1, + embed_dim=66, + depths=[2, 2, 2], + num_heads=[3, 6, 12], + window_size=5, + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + img_range=1., + resi_connection='1conv', + bottleneck_depth=2, + bottleneck_heads=24, + conv_downsample_first=True, + out_chans=1, + no_residual_learning=False, + **kwargs): + super(HUMUSNet, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans if out_chans is None else out_chans + + self.center_slice_out = (out_chans == 1) + + self.img_range = img_range + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + self.conv_downsample_first = conv_downsample_first + self.no_residual_learning = no_residual_learning + + ##################################################################################################### + ################################### 1, input block ################################### + input_conv_dim = embed_dim + self.conv_first = nn.Conv2d(num_in_ch, input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** self.num_layers) + self.mlp_ratio = mlp_ratio + + # Downsample for low-res feature extraction + if self.conv_downsample_first: + img_size = [im //2 for im in img_size] + self.conv_down_block = ConvBlock(input_conv_dim // 2, input_conv_dim, 0.0) + self.conv_down = DownsampConvBlock(input_conv_dim, input_conv_dim) + + # split image into non-overlapping patches + if patch_size > 1: + self.patch_embed = PatchEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + if patch_size > 1: + self.patch_unembed = PatchUnEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, out_chans=embed_dim, norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build MUST + # encoder + self.layers_down = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** i_layer) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + 
patches_resolution[1] // dim_scaler), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='D', + ) + self.layers_down.append(layer) + + # bottleneck + dim_scaler = (2 ** self.num_layers) + self.layer_bottleneck = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=bottleneck_depth, + num_heads=bottleneck_heads, + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='B', + ) + + # decoder + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** (self.num_layers - i_layer - 1)) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='U', + ) + self.layers_up.append(layer) + + self.norm_down = norm_layer(self.num_features) + self.norm_up = norm_layer(self.embed_dim) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(input_conv_dim, input_conv_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(input_conv_dim, input_conv_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim, 3, 1, 1)) + + # Upsample if needed + if self.conv_downsample_first: + self.conv_up_block = ConvBlock(input_conv_dim, input_conv_dim // 2, 0.0) + self.conv_up = TransposeConvBlock(input_conv_dim, input_conv_dim // 2) + + ##################################################################################################### + ################################ 3, output block ################################ + self.conv_last = nn.Conv2d(input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, num_out_ch, 3, 1, 1) + self.output_norm = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + 
@torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + + # divisible by window size + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + + # divisible by total downsampling ratio + # this could be done more efficiently by combining the two + total_downsamp = int(2 ** (self.num_layers - 1)) + pad_h = (total_downsamp - h % total_downsamp) % total_downsamp + pad_w = (total_downsamp - w % total_downsamp) % total_downsamp + x = F.pad(x, (0, pad_w, 0, pad_h)) + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # encode + skip_cons = [] + for i, layer in enumerate(self.layers_down): + x, skip = layer(x, x_size=layer.input_resolution) + skip_cons.append(skip) + x = self.norm_down(x) # B L C + + # bottleneck + x = self.layer_bottleneck(x, self.layer_bottleneck.input_resolution) + + # decode + for i, layer in enumerate(self.layers_up): + x = layer(x, x_size=layer.input_resolution, skip=skip_cons[-i-1]) + x = self.norm_up(x) + x = self.patch_unembed(x, x_size) + return x + + def forward(self, x): + # print("input shape:", x.shape) + # print("input range:", x.max(), x.min()) + C, H, W = x.shape[1:] + center_slice = (C - 1) // 2 + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.conv_downsample_first: + x_first = self.conv_first(x) + x_down = self.conv_down(self.conv_down_block(x_first)) + res = self.conv_after_body(self.forward_features(x_down)) + res = self.conv_up(res) + res = torch.cat([res, x_first], dim=1) + res = self.conv_up_block(res) + + res = self.conv_last(res) + + if self.no_residual_learning: + x = res + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + res + else: + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + + if self.no_residual_learning: + x = self.conv_last(res) + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + self.conv_last(res) + + # print("output range before scaling:", x.max(), x.min()) + # print("self.mean:", self.mean) + + if self.center_slice_out: + x = x / self.img_range + self.mean[:, center_slice, ...].unsqueeze(1) + else: + x = x / self.img_range + self.mean + + x = self.output_norm(x) + + return x[:, :, :H, :W] + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + +class DownsampConvBlock(nn.Module): + """ + A Downsampling Convolutional Block that consists of one strided convolution + layer followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.Conv2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H/2, W/2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return HUMUSNet(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/kspace_mUnet_AttnFusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/kspace_mUnet_AttnFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..02588c24559c0604ddbbd03a14b3b32e6cff2c8d --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/kspace_mUnet_AttnFusion.py @@ -0,0 +1,300 @@ +""" +Xiaohan Xing, 2023/11/22 +两个模态分别用CNN提取多个层级的特征, 每个层级都把特征沿着宽度拼接起来,得到(h, 2w, c)的特征。 +然后学习得到(h, 2w)的attention map, 和原始的特征相乘送到后面的层。 +""" + +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.attn_layers = nn.ModuleList() + + for l in range(self.num_pool_layers): + + self.attn_layers.append(nn.Sequential(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob), + nn.Conv2d(chans*(2**l), 2, kernel_size=1, stride=1), + nn.Sigmoid())) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, attn_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.attn_layers): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat之后求attention, 然后进行modulation. 
+ attn = attn_layer(torch.cat((output1, output2), 1)) + # print("spatial attention:", attn[:, 0, :, :].unsqueeze(1).shape) + # print("output1:", output1.shape) + output1 = output1 * (1 + attn[:, 0, :, :].unsqueeze(1)) + output2 = output2 * (1 + attn[:, 1, :, :].unsqueeze(1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/kspace_mUnet_concat.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/kspace_mUnet_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..d16bfdc83305d3e1f490e5ca4b445c6d730557ce --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/kspace_mUnet_concat.py @@ -0,0 +1,351 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Predict the low-frequency mask ############## +class Mask_Predictor(nn.Module): + def __init__(self, in_chans, out_chans, chans, drop_prob): + super(Mask_Predictor, self).__init__() + self.conv1 = ConvBlock(in_chans, chans, drop_prob) + self.conv2 = ConvBlock(chans, chans*2, drop_prob) + self.conv3 = ConvBlock(chans*2, chans*4, drop_prob) + + # self.conv4 = ConvBlock(chans*4, 1, drop_prob) + + self.FC = nn.Linear(chans*4, out_chans) + self.act = nn.Sigmoid() + + + def forward(self, x): + ### 三层卷积和down-sampling. + # print("[Mask predictor] input data range:", x.max(), x.min()) + output = F.avg_pool2d(self.conv1(x), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer1_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv2(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer2_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv3(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer3_feature range:", output.max(), output.min()) + + feature = F.adaptive_avg_pool2d(F.relu(output), (1, 1)).squeeze(-1).squeeze(-1) + # print("mask_prediction feature:", output[:, :5]) + # print("[Mask predictor] bottleneck feature range:", feature.max(), feature.min()) + + output = self.FC(feature) + # print("output before act:", output) + output = self.act(output) + # print("mask output:", output) + + return output + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, 
stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.mask_predictor1 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + self.mask_predictor2 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + ### 预测做low-freq DC的mask区域. + ### 输出三个维度, 前两维是坐标,最后一维是mask内部的权重 + t1_mask = self.mask_predictor1(output1) + t2_mask = self.mask_predictor2(output2) + + # print("t1_mask and t2_mask:", t1_mask.shape, t2_mask.shape) + t1_mask_coords, t2_mask_coords = t1_mask[:, :2], t2_mask[:, :2] + t1_DC_weight, t2_DC_weight = t1_mask[:, -1], t2_mask[:, -1] + + + # ### 预测做low-freq DC的DC weight + # t1_DC_weight = self.mask_predictor1(output1) + # t2_DC_weight = self.mask_predictor2(output2) + + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight ## + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mARTUnet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9607d41ca14562a94e542a1804af5aeaf15bc123 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mARTUnet.py @@ -0,0 +1,563 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +2023/07/20, Xiaohan Xing +将两个模态的图像在输入端concat, 得到num_channels = 2的input, 然后送到ART模型进行T2 modality的重建。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # print("feature after layer_norm:", x.max(), x.min()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", 
attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=2, + out_channels=1, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img, aux_img): + stack = [] + bs, _, H, W = inp_img.shape + fuse_img = torch.cat((inp_img, aux_img), 1) + inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + 
self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=2, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_ARTfusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..610b1e1e38d74f695a3a19f22a6a80f3152bf713 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_ARTfusion.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import CS_TransformerBlock as TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + # output1, output2 = image1, image2 + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
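+ ### Per-level flow, described in English (grounded in the code below): each modality first
+ ### passes through its own ConvBlock; these pre-fusion features are pushed onto stack1/stack2
+ ### as skip connections for the decoders. The two maps are then concatenated along channels,
+ ### flattened to (B, H*W, 2C) tokens, refined by this level's ART TransformerBlocks
+ ### (alternating dense/sparse attention via ds_flag), split back into the two modalities,
+ ### and average-pooled by 2x before entering the next encoder level.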
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_ARTfusion_SeqConcat.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_ARTfusion_SeqConcat.py new file mode 100644 index 0000000000000000000000000000000000000000..55280925415f1db47cf746b59fb9710d47c9bff2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_ARTfusion_SeqConcat.py @@ -0,0 +1,374 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any + +import matplotlib.pyplot as plt +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet_new import ConcatTransformerBlock, TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class 
mUnet_ARTfusion_SeqConcat(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4. + ): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.vis_feature_maps = False + self.vis_attention_maps = False + # self.vis_feature_maps = args.MODEL.VIS_FEAT_MAPS + # self.vis_attention_maps = args.MODEL.VIS_ATTN_MAPS + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + if l < 2: + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * (2 ** (l + 1))), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + else: + self.fuse_layers.append(nn.ModuleList([ConcatTransformerBlock(dim=int(chans * (2 ** l)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + feature_maps = { + 'modal1': {'before': [], 'after': []}, + 'modal2': {'before': [], 'after': []} + } + + attention_maps = [] + """ + attention_maps = [ + unet layer 1 attention maps (x2): [map from TransBlock1, TransBlock2, ...], + unet layer 2 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 3 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 4 attention maps (x6): [map from TransBlock1, TransBlock2, ...], + ] + """ + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
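+ ### Per-level flow in the SeqConcat variant (grounded in the code below): both modalities go
+ ### through their own ConvBlock. For the first two levels (l < 2) the features are concatenated
+ ### along channels and refined jointly by TransformerBlock; for deeper levels each modality is
+ ### flattened to its own (B, H*W, C) token sequence and the pair is fused by
+ ### ConcatTransformerBlock, which returns one refined sequence per modality. With
+ ### vis_attention_maps enabled, each block also returns its attention map, collected per
+ ### U-Net level into attention_maps.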
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + if self.vis_feature_maps: + feature_maps['modal1']['before'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['before'].append(output2.detach().cpu().numpy()) + + ### 两个模态的特征concat + # output = torch.cat((output1, output2), 1) + + + + unet_layer_attn_maps = [] + """ + unet_layer_attn_maps: [ + attn_map from TransformerBlock1: (B, nHGroups, nWGroups, winSize, winSize), + attn_map from TransformerBlock2: (B, nHGroups, nWGroups, winSize, winSize), + ... + ] + """ + if l < 2: + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + else: + output1 = rearrange(output1, "b c h w -> b (h w) c").contiguous() + output2 = rearrange(output2, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + if l < 2: + if self.vis_attention_maps: + output, attn_map = layer(output, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output = layer(output, [H // (2**l), W // (2**l)]) + + else: + if self.vis_attention_maps: + output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + attention_maps.append(unet_layer_attn_maps) + + if l < 2: + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + else: + output1 = rearrange(output1, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output2 = rearrange(output2, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + # if self.vis_attention_maps: + # output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) # attn_map: (B, nHGroups, nWGroups, winSize, winSize) + # unet_layer_attn_maps.append(attn_map) + # else: + # output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + # attention_maps.append(unet_layer_attn_maps) + + + + if self.vis_feature_maps: + feature_maps['modal1']['after'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['after'].append(output2.detach().cpu().numpy()) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + if self.vis_feature_maps: + return output1, output2, feature_maps#, relation_stack1, relation_stack2 #, t1_features, t2_features + elif self.vis_attention_maps: + return output1, output2, attention_maps + else: + return output1, output2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
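+ In English: the H and W axes are regrouped into 15x15 tiles stacked along the channel axis,
+ each tile is adaptively average-pooled to a 5x5 grid, and the pairwise cosine similarity
+ between the resulting 25 positions is returned as a (bs, 25, 25) matrix. Illustrative
+ sketch (an assumption of this note: H and W are multiples of 15, e.g. the 240x240 maps
+ used elsewhere in this repo):
+     feat = torch.randn(2, 32, 240, 240)
+     rel = get_relation_matrix(feat)
+     assert rel.shape == (2, 25, 25)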
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion_SeqConcat(args) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_Restormer_fusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_Restormer_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8ab29b641fa0231ad0d163224a4a90b44c32a9 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_Restormer_fusion.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .restormer import TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8]): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.Sequential(*[TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[0], ffn_expansion_factor=2.66, \ + bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + ### 将encoder中multi-modal fusion之后的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = fuse_layer(output) + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_transformer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2a209758203b198502154b75496333ef91717a --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mUnet_transformer.py @@ -0,0 +1,387 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+ +Xiaohan Xing, 2023/06/08 +分别用CNN提取两个模态的特征,然后送到transformer进行特征交互, 将经过transformer提取的特征和之前CNN提取的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Decoder_wo_skip(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder_wo_skip, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + output = transpose_conv(output) + output = conv(output) + + return output + + + + + +class mUnet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch*2 + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1, output2 = image1, image2 + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack1, output1 = self.encoder1(output1) ### output size: [4, 256, 15, 15] + stack2, output2 = self.encoder2(output2) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output1) # [4, 225, 256], 225 is the number of patches. + patch_embed_input1 = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + n_embed = self.to_patch_embedding(output2) # [4, 225, 256], 225 is the number of patches. 
+ patch_embed_input2 = F.normalize(n_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input1, patch_embed_input2), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1]/2)), int(np.sqrt(feature_output.shape[1]/2)) + patch_num = feature_output.shape[1] + feature_output1 = feature_output[:, :patch_num//2, :] + feature_output2 = feature_output[:, patch_num//2:, :] + feature_output1 = feature_output1.contiguous().view(b, feature_output1.shape[-1], h, w) # [4,c,15,15] + feature_output2 = feature_output2.contiguous().view(b, feature_output2.shape[-1], h, w) + + output = torch.cat((output1, output2), 1) + torch.cat((feature_output1, feature_output2), 1) + # print("CNN feature range:", output1.max(), output1.min()) + # print("Transformer feature range:", feature_output1.max(), feature_output1.min()) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_Transformer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee0c4784d04638df4ca259613777dde34980a2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/04 +Concatenate the input images from different modalities and input it into the UNet. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_ART.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..1a09f2900a8924cdc1a7c373270e6ff8aadcfa45 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_ART.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = 
nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_ART_v2.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_ART_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe6bc0d9192a126f39932b40443badb5b60068 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_ART_v2.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. 
+ num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
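+
+        Note (added description): each encoder level applies its ConvBlock, a 1x1
+        "transform" conv, and a stack of ART TransformerBlocks on the flattened
+        (N, H*W, C) tokens; the ART output, reshaped back to (N, C, H, W), is both
+        stored for the decoder skip connection and average-pooled for the next level.
+        The spatial size passed to the ART layers is derived from self.img_size, so
+        240x240 inputs are assumed.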
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_early_fusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_early_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d062f057a2080a471005125a00ee27453a4b0706 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_early_fusion.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +encapsulate the encoder and decoder for the Unet model. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + stack, output = self.encoder(output) + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + output = self.decoder(output, stack) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_late_fusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b48e7d4445afd6a5bc1eaa2b065c37431585e94 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_late_fusion.py @@ -0,0 +1,229 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +use Unet for the reconstruction of each modality, concat the intermediate features of both modalities. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("decoder output:", output.shape) + + return output + + + + +class mmUnet_late(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack1, output1 = self.encoder1(image) + stack2, output2 = self.encoder2(aux_image) + ### fuse two modalities to get multi-modal representation. + output = torch.cat([output1, output2], 1) + output = self.conv(output) ### from [4, 512, 15, 15] to [4, 512, 15, 15] + + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet_late(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_restormer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..88c7920d65c58ab82a00309a1ae501a8d5b33804 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_restormer.py @@ -0,0 +1,237 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, 
drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_restormer_3blocks.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_restormer_3blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e3afe68612727d3195872a121ec8040fb0f66ef2 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_restormer_3blocks.py @@ -0,0 +1,235 @@ +""" +2023/07/18, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +之前 3 blocks的模型深层特征存在严重的gradient vanishing问题。现在把模型的encoder和decoder都改成只有3个blocks, 看是否还存在gradient vanishing问题。 +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 3, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2] + heads=[1, 2, 4] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_restormer_v2.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b51d3a218057206863230973d557173e891d3cd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mmunet_restormer_v2.py @@ -0,0 +1,220 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +这个版本中是将Unet encoder提取的特征直接连接到decoder, 经过restormer refine的特征只是传递到下一层的encoder block. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + output = feature + + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mtrans_mca.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mtrans_mca.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4c89df4af71e986fa7c24d522e6732371f3128 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mtrans_mca.py @@ -0,0 +1,119 @@ +import torch +from torch import nn + +from .mtrans_transformer import Mlp + +# implement by YunluYan + + +class LinearProject(nn.Module): + def __init__(self, dim_in, dim_out, fn): + super(LinearProject, self).__init__() + need_projection = dim_in != dim_out + self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity() + self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity() + self.fn = fn + + def forward(self, x, *args, **kwargs): + x = self.project_in(x) + x = self.fn(x, *args, **kwargs) + x = self.project_out(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): + super(MultiHeadCrossAttention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim*2, bias=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x, complement): + + # x [B, HW, C] + B_x, N_x, C_x = x.shape + + x_copy = x + + complement = torch.cat([x, complement], 1) + + B_c, N_c, C_c = complement.shape + + # q [B, heads, HW, C//num_heads] + q = self.to_q(x).reshape(B_x, N_x, self.num_heads, C_x//self.num_heads).permute(0, 2, 1, 3) + kv = self.to_kv(complement).reshape(B_c, N_c, 2, self.num_heads, C_c//self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_x, N_x, C_x) + + x = x + x_copy + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio = 1., attn_drop=0., proj_drop=0.,drop_path = 0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super(CrossTransformerEncoderLayer, self).__init__() + self.x_norm1 = norm_layer(dim) + self.c_norm1 = norm_layer(dim) + + self.attn = MultiHeadCrossAttention(dim, num_heads, attn_drop, proj_drop) + + self.x_norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) + + self.drop1 = nn.Dropout(drop_path) + self.drop2 = nn.Dropout(drop_path) + + def forward(self, x, complement): + x = self.x_norm1(x) + complement = self.c_norm1(complement) + + x = x + self.drop1(self.attn(x, complement)) + x = x + self.drop2(self.mlp(self.x_norm2(x))) + return x + + + +class CrossTransformer(nn.Module): + def __init__(self, x_dim, c_dim, depth, num_heads, mlp_ratio =1., attn_drop=0., proj_drop=0., drop_path =0.): + super(CrossTransformer, self).__init__() + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LinearProject(x_dim, c_dim, CrossTransformerEncoderLayer(c_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)), + LinearProject(c_dim, x_dim, 
CrossTransformerEncoderLayer(x_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)) + ])) + + def forward(self, x, complement): + for x_attn_complemnt, complement_attn_x in self.layers: + x = x_attn_complemnt(x, complement=complement) + x + complement = complement_attn_x(complement, complement=x) + complement + return x, complement + + + + + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mtrans_net.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mtrans_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e80dadb45725603c7ff1da664a45533575db25 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mtrans_net.py @@ -0,0 +1,145 @@ +""" +This script implement the paper "Multi-Modal Transformer for Accelerated MR Imaging (TMI 2022)" +""" + + +import torch + +from torch import nn +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler +from .mtrans_mca import CrossTransformer + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(INPUT_DIM, HEAD_HIDDEN_DIM): + return ReconstructionHead(INPUT_DIM, HEAD_HIDDEN_DIM) + + +class CrossCMMT(nn.Module): + + def __init__(self, args): + super(CrossCMMT, self).__init__() + + INPUT_SIZE = 240 + INPUT_DIM = 1 # the channel of input + OUTPUT_DIM = 1 # the channel of output + HEAD_HIDDEN_DIM = 16 # the hidden dim of Head + TRANSFORMER_DEPTH = 4 # the depth of the transformer + TRANSFORMER_NUM_HEADS = 4 # the head's num of multi head attention + TRANSFORMER_MLP_RATIO = 3 # the MLP RATIO Of transformer + TRANSFORMER_EMBED_DIM = 256 # the EMBED DIM of transformer + P1 = 8 + P2 = 16 + CTDEPTH = 4 + + + self.head = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + self.head2 = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + + x_patch_dim = HEAD_HIDDEN_DIM * P1 ** 2 + x_num_patches = (INPUT_SIZE // P1) ** 2 + + complement_patch_dim = HEAD_HIDDEN_DIM * P2 ** 2 + complement_num_patches = (INPUT_SIZE // P2) ** 2 + + self.x_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P1, + p2=P1), + ) + + self.complement_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P2, + p2=P2), + ) + + self.x_pos_embedding = nn.Parameter(torch.randn(1, x_num_patches, x_patch_dim)) + self.complement_pos_embedding = nn.Parameter(torch.randn(1, complement_num_patches, complement_patch_dim)) + + ### + self.cross_transformer = CrossTransformer(x_patch_dim, complement_patch_dim, CTDEPTH, + TRANSFORMER_NUM_HEADS, TRANSFORMER_MLP_RATIO) + + self.p1 = P1 + self.p2 = P2 + + self.tail1 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + self.tail2 = 
nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1)
+
+
+    def forward(self, x, complement):
+        x = self.head(x)
+        complement = self.head2(complement)
+
+        x = x + complement
+
+        b, _, h, w = x.shape
+
+        _, _, c_h, c_w = complement.shape
+
+        ### get the patch embedding of each modality (with the position encoding added)
+        x = self.x_patch_embbeding(x)
+        x += self.x_pos_embedding
+
+        complement = self.complement_patch_embbeding(complement)
+        complement += self.complement_pos_embedding
+
+        ### feed the patch embeddings of the two modalities into the cross transformer to extract features.
+        x, complement = self.cross_transformer(x, complement)
+
+        c = int(x.shape[2] / (self.p1 * self.p1))
+        H = int(h / self.p1)
+        W = int(w / self.p1)
+
+        x = x.reshape(b, H, W, self.p1, self.p1, c)  # b H W p1 p2 c
+        x = x.permute(0, 5, 1, 3, 2, 4)  # b c H p1 W p2
+        x = x.reshape(b, -1, h, w)
+        x = self.tail1(x)
+
+        complement = complement.reshape(b, int(c_h/self.p2), int(c_w/self.p2), self.p2, self.p2, int(complement.shape[2]/self.p2/self.p2))
+        complement = complement.permute(0, 5, 1, 3, 2, 4)
+        complement = complement.reshape(b, -1, c_h, c_w)
+
+        complement = self.tail2(complement)
+
+        return x, complement
+
+
+def build_model(args):
+    return CrossCMMT(args)
+
diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mtrans_transformer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mtrans_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..61314a6a81247b0740a1e98d078242b26ce73f90
--- /dev/null
+++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/mtrans_transformer.py
@@ -0,0 +1,252 @@
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+
+
+class Mlp(nn.Module):
+    # two mlp, fc-relu-drop-fc-relu-drop
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention_Encoder(nn.Module):
+    def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+                 proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Dropout(proj_drop) if False else nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        if kv_reduced_dim is not None and type(kv_reduced_dim) == int:
+            # a bare nn.Linear() has no sizes and raises a TypeError; placeholder dims
+            # are used here since fc_k is never referenced in forward()
+            self.fc_k = nn.Linear(dim, kv_reduced_dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        # qkv shape [3, N, num_head, HW, C//num_head]
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # [N, num_head, HW, C//num_head]
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Attention_Decoder(nn.Module):
+    def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+
head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return 
(1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
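The `drop_path` helper above implements stochastic depth: during training, whole samples in the batch are zeroed with probability `drop_prob` and the survivors are rescaled by `1/keep_prob`, so the expected activation is unchanged; with `training=False` the input passes through untouched. A small check using the `drop_path` defined above (shapes are assumed):

    import torch

    x = torch.ones(4, 3, 8, 8)
    y = drop_path(x, drop_prob=0.25, training=True)
    # Per-sample values are either 0.0 (dropped) or 1/0.75 ~= 1.333 (kept and rescaled),
    # so the expectation of y equals x.
    print(y.view(4, -1).mean(dim=1))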
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(args, embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=args.MODEL.TRANSFORMER_DEPTH, + num_heads=args.MODEL.TRANSFORMER_NUM_HEADS, + mlp_ratio=args.MODEL.TRANSFORMER_MLP_RATIO) + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_concat_decomp.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_concat_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..07c6f64804dd586927814801b167163f0b65f58f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_concat_decomp.py @@ -0,0 +1,295 @@ +""" +Xiaohan Xing, 2023/06/25 +两个模态分别用CNN提取多个层级的特征, 每个层级的特征都经过concat之后得到fused feature, +然后每个模态都通过一层conv变换得到2C*h*w的特征, 其中C个channels作为common feature, 另C个channels作为specific features. +利用CDDFuse论文中的decomposition loss约束特征解耦。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.transform_layers1 = nn.ModuleList() + self.transform_layers2 = nn.ModuleList() + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.transform_layers1.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.transform_layers2.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + stack1_common, stack1_specific, stack2_common, stack2_specific = [], [], [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_transform_layer, net2_transform_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.transform_layers1, self.transform_layers2, \ + self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
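    # (Illustrative note) At level l, the transform layer widens the chans*2**l feature map to
    # twice that many channels; the first half is taken as the modality's 'common' features and
    # the second half as its 'specific' features. The fuse ConvBlock below therefore receives
    # 3 * chans*2**l channels (own common + own specific + the other modality's common) and maps
    # them back to chans*2**l so the next encoder level sees its expected width.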
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + # print("original feature:", output1.shape) + + ### 分别提取两个模态的common feature和specific feature + output1 = net1_transform_layer(output1) + # print("transformed feature:", output1.shape) + output1_common = output1[:, :(output1.shape[1]//2), :, :] + output1_specific = output1[:, (output1.shape[1]//2):, :, :] + output2 = net2_transform_layer(output2) + output2_common = output2[:, :(output2.shape[1]//2), :, :] + output2_specific = output2[:, (output2.shape[1]//2):, :, :] + + stack1_common.append(output1_common) + stack1_specific.append(output1_specific) + stack2_common.append(output2_common) + stack2_specific.append(output2_specific) + + ### 将一个模态的两组feature和另一模态的common feature合并, 然后经过conv变换得到和初始特征相同维度的fused feature送到下一层。 + output1 = net1_fuse_layer(torch.cat((output1_common, output1_specific, output2_common), 1)) + output2 = net2_fuse_layer(torch.cat((output2_common, output2_specific, output1_common), 1)) + # print("fused feature:", output1.shape) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, stack1_common, stack1_specific, stack2_common, stack2_specific + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
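The module docstring of this file refers to a CDDFuse-style decomposition loss over the common/specific feature stacks returned by the forward pass above, but the loss itself is not defined here. The following is only a rough sketch of one plausible formulation (an assumption, loosely following CDDFuse): common features of the two modalities are encouraged to correlate while their specific features are pushed toward zero correlation.

    import torch

    def _corr(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Pearson-style correlation between two feature maps
        a = (a - a.mean()) / (a.std() + eps)
        b = (b - b.mean()) / (b.std() + eps)
        return (a * b).mean()

    def decomposition_loss(common1, specific1, common2, specific2, eps: float = 1e-6):
        # each argument is a list of per-level feature maps (stack*_common / stack*_specific)
        loss = torch.zeros(())
        for c1, s1, c2, s2 in zip(common1, specific1, common2, specific2):
            loss = loss + _corr(s1, s2) ** 2 / (_corr(c1, c2) ** 2 + eps)
        return loss / len(common1)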
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_multi_concat.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_multi_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90373e1134210809b98fdc64e3d51d76123855 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_multi_concat.py @@ -0,0 +1,306 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if 
torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + # output1, output2 = image1, image2 + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
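    # (Illustrative note) Each branch input is the modality image concatenated with its k-space
    # reconstruction (2 channels, matching input_dim=2). Within this loop, the pooled features of
    # the two branches (chans*2**l channels each) are concatenated to chans*2**(l+1) channels and
    # the fuse ConvBlock maps them back to chans*2**l, so the next encoder layer receives its
    # expected channel count; note that output2's fusion uses the already-fused output1.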
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # relation1 = get_relation_matrix(output1) + # relation2 = get_relation_matrix(output2) + # relation_stack1.append(relation1) + # relation_stack2.append(relation2) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + # output1 = net1_layer(output1) + # output2 = net2_layer(output2) + + # ### 两个模态的特征concat + # fused_output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + # fused_output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # stack1.append(fused_output1) + # stack2.append(fused_output2) + + # ### downsampling features + # output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + # output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
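`get_relation_matrix` above tiles a feature map (whose height and width must be multiples of 15) into 15x15 blocks, adaptively pools to a 5x5 grid, and returns the 25x25 cosine-similarity matrix between those pooled positions, so its diagonal is approximately 1. A quick sanity check with a 240x240 input (hypothetical batch and channel sizes):

    import torch

    feat = torch.randn(2, 8, 240, 240)          # (batch, channels, H, W); H, W multiples of 15
    rel = get_relation_matrix(feat)             # (2, 25, 25)
    diag = torch.diagonal(rel, dim1=1, dim2=2)  # cosine of each position with itself
    print(rel.shape, torch.allclose(diag, torch.ones_like(diag), atol=1e-4))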
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_multi_sum.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_multi_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4efb13b54ccbeaeb34fc51d3e72b0c50ee242 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_multi_sum.py @@ -0,0 +1,270 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if 
torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers): + + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + + ## 两个模态的特征求平均. + fused_output1 = (output1 + output2)/2.0 + fused_output2 = (output1 + output2)/2.0 + + output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features #, relation_stack1, relation_stack2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_multi_transfuse.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_multi_transfuse.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ce250a078ba0847c729c1ef1b76452f64d0f07 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_multi_transfuse.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +2023/06/23 +每个层级的特征用TransFuse layer融合之后, 将两个模态的original features和transfuse features concat, +然后在每个模态中经过conv层变换得到下一层的输入。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .TransFuse import TransFuse_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_TransFuse(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + ### conv layer to transform the features after Transfuse layer. + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + self.n_anchors = 15 + self.avgpool = nn.AdaptiveAvgPool2d((self.n_anchors, self.n_anchors)) + self.transfuse_layer1 = TransFuse_layer(n_embd=self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer2 = TransFuse_layer(n_embd=2*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer3 = TransFuse_layer(n_embd=4*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer4 = TransFuse_layer(n_embd=8*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + + self.transfuse_layers = nn.Sequential(self.transfuse_layer1, self.transfuse_layer2, self.transfuse_layer3, self.transfuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, transfuse_layer, net1_fuse_layer, net2_fuse_layer in zip(self.encoder1.down_sample_layers, \ + self.encoder2.down_sample_layers, self.transfuse_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### extract multi-level features from the encoder of each modality. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + ### Transformer-based multi-modal fusion layer + feature1 = self.avgpool(output1) + feature2 = self.avgpool(output2) + feature1, feature2 = transfuse_layer(feature1, feature2) + feature1 = F.interpolate(feature1, scale_factor=output1.shape[-1]//self.n_anchors, mode='bilinear') + feature2 = F.interpolate(feature2, scale_factor=output2.shape[-1]//self.n_anchors, mode='bilinear') + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
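The anchor-based fusion above reduces each per-level feature map to a 15x15 grid, fuses the two modalities at that coarse resolution, and bilinearly interpolates the result back to the original size before concatenation. `TransFuse_layer` itself comes from `.TransFuse` and is not shown in this diff, so a simple element-wise average stands in for it in this sketch (all sizes hypothetical):

    import torch
    import torch.nn.functional as F

    n_anchors = 15
    f1 = torch.randn(1, 32, 120, 120)                  # modality-1 feature at some level
    f2 = torch.randn(1, 32, 120, 120)                  # modality-2 feature at the same level

    a1 = F.adaptive_avg_pool2d(f1, n_anchors)          # (1, 32, 15, 15) anchor grid
    a2 = F.adaptive_avg_pool2d(f2, n_anchors)
    a1, a2 = (a1 + a2) / 2, (a1 + a2) / 2              # stand-in for transfuse_layer(a1, a2)
    scale = f1.shape[-1] // n_anchors                  # 120 // 15 = 8
    g1 = F.interpolate(a1, scale_factor=scale, mode="bilinear")
    g2 = F.interpolate(a2, scale_factor=scale, mode="bilinear")

    fused_input = torch.cat((f1, f2, g1, g2), dim=1)   # (1, 128, 120, 120) -> fuse ConvBlock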
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_TransFuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_swinfusion.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_swinfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6203c4331744f1da70f18e403360196a419b4cb5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/munet_swinfusion.py @@ -0,0 +1,268 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .SwinFuse_layer import SwinFusion_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_SwinFuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过SwinFusion的方式融合。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.fuse_type = 'swinfuse' + self.img_size = 240 + # self.fuse_type = 'sum' + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # Swin transformer fusion modules + + self.fuse_layer1 = SwinFusion_layer(img_size=(self.img_size//2, self.img_size//2), patch_size=4, window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer2 = SwinFusion_layer(img_size=(self.img_size//4, self.img_size//4), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=2*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer3 = SwinFusion_layer(img_size=(self.img_size//8, self.img_size//8), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=4*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer4 = SwinFusion_layer(img_size=(self.img_size//16, self.img_size//16), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=8*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.swinfusion_layers = nn.Sequential(self.fuse_layer1, self.fuse_layer2, self.fuse_layer3, self.fuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, swinfuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.swinfusion_layers): + output1 = net1_layer(output1) + stack1.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # print("output of encoders:", output1.shape, output2.shape) + # print("CNN feature range:", output1.max(), output2.min()) + + ### Swin transformer-based multi-modal fusion. + feature1, feature2 = swinfuse_layer(output1, output2) + output1 = (output1 + feature1 + output2 + feature2)/4.0 + output2 = (output1 + feature1 + output2 + feature2)/4.0 + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_SwinFuse(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/original_MINet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/original_MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a115872320f677d8b34264eed95c22fff0df9c4 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/original_MINet.py @@ -0,0 +1,335 @@ +from fastmri.models import common +import torch +import torch.nn as nn +import torch.nn.functional as F + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual 
Group (RG)
+class ResidualGroup(nn.Module):
+    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
+        super(ResidualGroup, self).__init__()
+        modules_body = []
+        modules_body = [
+            RCAB(
+                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
+            for _ in range(n_resblocks)]
+        modules_body.append(conv(n_feat, n_feat, kernel_size))
+        self.body = nn.Sequential(*modules_body)
+
+    def forward(self, x):
+        res = self.body(x)
+        res += x
+        return res
+
+
+class SR_Branch(nn.Module):
+    def __init__(self, n_resgroups, n_resblocks, n_feats, conv=common.default_conv):
+        super(SR_Branch, self).__init__()
+
+        self.n_resgroups = n_resgroups
+        self.n_resblocks = n_resblocks
+        self.n_feats = n_feats
+        kernel_size = 3
+        reduction = 16
+
+        scale = 2
+        rgb_range = 255
+        n_colors = 1
+        res_scale = 0.1
+        act = nn.ReLU(True)
+
+        # define head module
+        modules_head = [conv(n_colors, n_feats, kernel_size)]
+
+        # define body module
+        modules_body = [
+            ResidualGroup(
+                conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \
+            for _ in range(n_resgroups)]
+
+        modules_body.append(conv(n_feats, n_feats, kernel_size))
+
+        # define tail module
+        modules_tail = [
+            common.Upsampler(conv, scale, n_feats, act=False),
+            conv(n_feats, n_feats, kernel_size)]  # n_colors
+
+        # NOTE: the original line `self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1)`
+        # referenced `rgb_mean`/`rgb_std`, which are never defined, and `add_mean` is not used anywhere
+        # in this file, so the statement is dropped to avoid a NameError at construction time.
+
+        self.head = nn.Sequential(*modules_head)
+        self.body = nn.Sequential(*modules_body)
+        self.csa = CSAM_Module(n_feats)
+        self.la = LAM_Module(n_feats)
+        self.last_conv = nn.Conv2d(n_feats*(n_resgroups+1), n_feats, 3, 1, 1)
+        self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1)
+        self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
+        self.tail = nn.Sequential(*modules_tail)
+        self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1)
+
+    def forward(self, x):
+        outputs = []
+        x = self.head(x)
+        outputs.append(x)
+
+        res = x
+
+        for name, midlayer in self.body._modules.items():
+            res = midlayer(res)
+
+            if name == '0':
+                res1 = res.unsqueeze(1)
+            else:
+                res1 = torch.cat([res.unsqueeze(1), res1], 1)
+
+        outputs.append(res1)
+
+        out1 = res
+        res = self.la(res1)
+        out2 = self.last_conv(res)
+        out1 = self.csa(out1)
+        out = torch.cat([out1, out2], 1)
+        res = self.last(out)
+
+        res += x
+
+        outputs.append(res)
+
+        x = self.tail(res)
+
+        return outputs, x
+
+    def load_state_dict(self, state_dict, strict=False):
+        own_state = self.state_dict()
+        for name, param in state_dict.items():
+            if name in own_state:
+                if isinstance(param, nn.Parameter):
+                    param = param.data
+                try:
+                    own_state[name].copy_(param)
+                except Exception:
+                    if name.find('tail') >= 0:
+                        print('Replace pre-trained upsampler to new one...')
+                    else:
+                        raise RuntimeError('While copying the parameter named {}, '
+                                           'whose dimensions in the model are {} and '
+                                           'whose dimensions in the checkpoint are {}.'
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class MINet(nn.Module): + def __init__(self, n_resgroups,n_resblocks, n_feats): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1,x2#x1=pd x2=pdfs + \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/restormer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3699cc809e438be4e1bd5efaba7d96e18d7f1602 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/restormer.py @@ -0,0 +1,307 @@ +## Restormer: Efficient Transformer for 
High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + # print("feature before FFN projection:", x.max(), x.min()) + x = self.project_out(x) + # print("FFN output:", x.max(), x.min()) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + b,c,h,w = x.shape + # 
print("Attention block input:", x.max(), x.min()) + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + # print("attention values:", attn) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("Attention block output:", out.max(), out.min()) + + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + # dim = 32, + # num_blocks = [4,4,4,4], + dim = 32, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + + # stack = [] + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + # print("encoder block1 feature:", out_enc_level1.shape, out_enc_level1.max(), out_enc_level1.min()) + # stack.append(out_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + # print("encoder block2 feature:", out_enc_level2.shape, out_enc_level2.max(), out_enc_level2.min()) + # stack.append(out_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + # print("encoder block3 feature:", out_enc_level3.shape, out_enc_level3.max(), out_enc_level3.min()) + # 
stack.append(out_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + # print("encoder block4 feature:", latent.shape, latent.max(), latent.min()) + # stack.append(latent) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + + return out_dec_level1 #, stack + + + +def build_model(args): + return Restormer() diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/restormer_block.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/restormer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..170ae6a826e5a3e356cb48f8c89500c2d5242816 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/restormer_block.py @@ -0,0 +1,321 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def 
__init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + # print("input to the LayerNorm layer:", x.max(), x.min()) + h, w = x.shape[-2:] + x = to_4d(self.body(to_3d(x)), h, w) + # print("output of the LayerNorm:", x.max(), x.min()) + return x + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature = 1.0 / ((dim//num_heads) ** 0.5) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + # print("input to the Attention layer:", x.max(), x.min()) + + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + # print("q range:", q.max(), q.min()) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + # print("scaling parameter:", self.temperature) + # print("attention matrix before softmax:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + attn = attn.softmax(dim=-1) + # print("attention matrix range:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + # print("v range:", v.max(), v.min(), v.mean()) + + out = (attn @ v) + # print("self-attention output:", out.max(), out.min()) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("output of attention layer:", out.max(), out.min()) + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, pre_norm): + super(TransformerBlock, self).__init__() + + self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + self.pre_norm = pre_norm + + def forward(self, x): + x = self.conv(x) + # 
print("restormer block input:", x.max(), x.min()) + if self.pre_norm: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + else: + x = self.norm1(x + self.attn(x)) + x = self.norm2(x + self.ffn(x)) + # x = self.norm1(self.attn(x)) + # x = self.norm2(self.ffn(x)) + # x = x + self.attn(x) + # x = x + self.ffn(x) + # print("restormer block output:", x.max(), x.min()) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim = 48, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + 
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + + + return out_dec_level1 + + + + +if __name__ == '__main__': + model = Restormer() + # print(model) + + A = torch.randn((1, 1, 240, 240)) + x = model(A) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/swinIR.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/swinIR.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfc9c175c02531ac80a05ba4abfc0776f752dd --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/swinIR.py @@ -0,0 +1,874 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * 
self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
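+    Note: the blocks in this layer alternate shift_size between 0 and window_size // 2, i.e. regular
+    W-MSA and shifted SW-MSA blocks, following the Swin Transformer design.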
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 4, 4, 6], + embed_dim=32, num_heads=[6, 6, 6, 6], mlp_ratio=4, upsampler='pixelshuffledirect') + + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/swin_transformer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..877bdfbffc91012e4ac31f41b838ebb116f4a825 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/swin_transformer.py @@ -0,0 +1,878 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
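+# Note: minimal usage sketch (illustrative only; the import path is an assumption about
+# how networks/compare_models is exposed on sys.path):
+#
+#     import torch
+#     from networks.compare_models.swin_transformer import build_model
+#     model = build_model(args=None)            # args is accepted but unused here
+#     out = model(torch.randn(1, 1, 240, 240))  # this file's SwinIR defaults to in_chans=1, upscale=1
+#     print(out.shape)                          # expected: torch.Size([1, 1, 240, 240])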
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=1, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + # print("shallow feature:", x.shape) + x = self.conv_after_body(self.forward_features(x)) + x + # print("deep feature:", x.shape) + x = self.upsample(x) + # print("after upsample:", x.shape) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x 
= self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + # return SwinIR(upscale=1, img_size=(240, 240), + # window_size=8, img_range=1., depths=[6, 6, 6, 6], + # embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + +if __name__ == '__main__': + height = 240 + width = 240 + window_size = 8 + model = SwinIR(upscale=1, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + x = torch.randn((1, 1, height, width)) + print("input:", x.shape) + x = model(x) + print("output:", x.shape) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet.zip b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet.zip new file mode 100644 index 0000000000000000000000000000000000000000..e7b71c1287551fd6119efd7c62da2196307998f5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:369de2a28036532c446409ee1f7b953d5763641ffd452675613e1bb9bba89eae +size 19354 diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet/trans_unet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet/trans_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa61dbadfcaee9b26594307adc90cdd1d2a84e3e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet/trans_unet.py @@ -0,0 +1,489 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math + +from os.path import join as pjoin + +import torch +import torch.nn as nn +import numpy as np + +from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair +from scipy import ndimage +from . 
import vit_seg_configs as configs +# import .vit_seg_configs as configs +from .vit_seg_modeling_resnet_skip import ResNetV2 + + +logger = logging.getLogger(__name__) + + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class 
Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + # print("image size:", img_size, "grid size:", grid_size) + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + # print("patch_size_real", patch_size_real) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + # print("input x:", x.shape) ### (4, 3, 240, 240) + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + # print("output of the hybrid model:", x.shape) ### (4, 1024, 15, 15) + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = 
np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + # print(self.encoder) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) ### "features" store the features from different layers. + # print("embedding_output:", embedding_output.shape) ## (B, n_patch, hidden) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + # print("encoded feature:", encoded.shape) + return encoded, attn_weights, features + + # return embedding_output, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class ReconstructionHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + 
self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + # print(self.blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): + super(VisionTransformer, self).__init__() + self.num_classes = num_classes + self.zero_head = zero_head + self.classifier = config.classifier + self.transformer = Transformer(config, img_size, vis) + self.decoder = DecoderCup(config) + self.segmentation_head = ReconstructionHead( + in_channels=config['decoder_channels'][-1], + out_channels=config['n_classes'], + kernel_size=3, + ) + self.activation = nn.Tanh() + self.config = config + + def forward(self, x): + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + ### hybrid CNN and transformer to extract features from the input images. + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + + # ### extract features with CNN only. 
+ # x, features = self.transformer(x) # (B, n_patch, hidden) + + x = self.decoder(x, features) + # logits = self.segmentation_head(x) + logits = self.activation(self.segmentation_head(x)) + # print("logits:", logits.shape) + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + + + +def build_model(args): + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + print(net) + # net.load_from(weights=np.load(config_vit.pretrained_path)) + return net + + +if __name__ == "__main__": + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + net.load_from(weights=np.load(config_vit.pretrained_path)) + + input_data = torch.rand(4, 1, 240, 240).cuda() + print("input:", input_data.shape) + output = net(input_data) + print("output:", output.shape) diff --git 
a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet/vit_seg_configs.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet/vit_seg_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bed7d3346e30328a38ac6fef1621f788d905df1e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet/vit_seg_configs.py @@ -0,0 +1,132 @@ +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 4 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + # config.patches.grid = (16, 16) + config.patches.grid = (15, 15) ### for the input image with size (240, 240) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'recon' + config.pretrained_path = '/home/xiaohan/workspace/MSL_MRI/code/pretrained_model/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 1 + config.n_skip = 3 + config.activation = 'tanh' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. 
customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae80ef7d96c46cb91bda849901aef34008e62ed --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py @@ -0,0 +1,160 @@ +import math + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. 
+ self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - 
x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/transformer_modules.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/transformer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..900042884305619e39ad9b792b3b1936dfbd37b3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/transformer_modules.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = 
attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
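+        # Worked example (editor's addition): _init_weights above calls trunc_normal_(w, std=.02)
+        # with the defaults mean=0, a=-2, b=2, so l = norm_cdf(-100) ≈ 0 and u = norm_cdf(100) ≈ 1
+        # and the uniform draw below spans almost all of (-1, 1). erfinv_ then inverts the normal
+        # CDF (up to the sqrt(2) factor applied in mul_ below), so after rescaling and shifting the
+        # values follow N(mean, std^2) restricted to [a, b].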
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=4, + num_heads=4, + mlp_ratio=3) + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unet.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dd935cd34a121d8079a8f5184d4ea29e15c8da --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unet.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F + + +class Unet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = image + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + # print("unet feature range:", output.max(), output.min()) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unet_restormer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f78aa5f192957b2706c6f691dfc66c427b65ae --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unet_restormer.py @@ -0,0 +1,234 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + # print("Unet feature range:", output.max(), output.min()) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unet_transformer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..982fe23cec321de6b636e04c51b62037054ea432 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unet_transformer.py @@ -0,0 +1,346 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +对于Unet中的high-level feature, 用Transformer进行特征交互,然后和Transformer之前的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): 
+ x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Unet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output = image + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack, output = self.encoder(output) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output) # [4, 225, 256], 225 is the number of patches. + patch_embed_input = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1])), int(np.sqrt(feature_output.shape[1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[-1], h, w) # [4,512*2,24,24] + + output = torch.cat((output, feature_output), 1) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output = self.decoder(output, stack) + + return output + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. 
+ + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Transformer(args) diff --git a/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unimodal_transformer.py b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unimodal_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c99214e10bee2f7a1be49f8560799ffbc14ed0a5 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/compare_models/unimodal_transformer.py @@ -0,0 +1,112 @@ +import torch + +from torch import nn +from .transformer_modules import build_transformer +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(): + return ReconstructionHead(input_dim=1, hidden_dim=12) + + + +class CMMT(nn.Module): + + def __init__(self, args): + super(CMMT, self).__init__() + + self.head = build_head() + + HEAD_HIDDEN_DIM = 12 + PATCH_SIZE = 16 + INPUT_SIZE = 240 + OUTPUT_DIM = 1 + + patch_dim = HEAD_HIDDEN_DIM* PATCH_SIZE ** 2 + num_patches = (INPUT_SIZE // PATCH_SIZE) ** 2 + + self.transformer = build_transformer(patch_dim) + + self.patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=PATCH_SIZE, p2=PATCH_SIZE), + ) + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, patch_dim)) + + + self.p1 = PATCH_SIZE + self.p2 = PATCH_SIZE + + self.tail = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x): + + b,_ , h, w = x.shape + + x = self.head(x) + + x= self.patch_embbeding(x) + + x += self.pos_embedding + + x = self.transformer(x) # b HW p1p2c + + c = int(x.shape[2]/(self.p1*self.p2)) + H = int(h/self.p1) + W = int(w/self.p2) + + x = 
x.reshape(b, H, W, self.p1, self.p2, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w,) + x = self.tail(x) + + return x + + +def build_model(args): + return CMMT(args) + + + + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/__init__.py b/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/blocks.py b/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/blocks.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
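+# Editor's note: per sample n and channel c, the forward pass below computes
+#   y[n, c] = scale[c] * (x[n, c] - mean_HW(x[n, c])) / sqrt(var_HW(x[n, c]) + eps) + shift[c]
+# i.e. the normalisation statistics are taken over the spatial dimensions only, with a learned
+# per-channel affine transform applied when affine=True.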
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/common_freq.py b/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6c78467db1391f6475d069dafda7f295b97ae1 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/common_freq.py @@ -0,0 +1,391 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + + +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if 
(scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs): + + out = self.layers(inputs) + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features): + super(ResBlock, self).__init__() + self.layers = nn.Sequential( + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1), + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + ) + + def forward(self, inputs): + return F.relu(self.layers(inputs) + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, 
kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [ + ResBlock(n_feat) for _ in range(n_resblocks)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x): + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = 
Attention(dim=channels) + self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = 
t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/model.py b/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2fd1e66946b7305f6e2228c92dab23da2c545dfe --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/frequency_model/model.py @@ -0,0 +1,376 @@ +import torch +from torch import nn +from . import common_freq as common + + +class TwoBranch(nn.Module): + def __init__(self, num_features, act, base_num_every_group, num_channels): + super(TwoBranch, self).__init__() + + self.num_features = num_features + self.act = act + self.num_channels = num_channels + + num_group = 4 + num_every_group = base_num_every_group + + self.init_T2_frq_branch() + self.init_T2_spa_branch( num_every_group) + self.init_T2_fre_spa_fusion() + + self.init_T1_frq_branch() + self.init_T1_spa_branch( num_every_group) + + self.init_modality_fre_fusion() + self.init_modality_spa_fusion() + + + def init_T2_frq_branch(self, ): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_fre = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + + self.down1_fre = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_down2_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down2_fre = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_down3_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down3_fre = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_neck_fre = [common.FreBlock9(self.num_features, ) + ] + self.neck_fre = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo = nn.Sequential(common.FreBlock9(self.num_features )) + + modules_up1_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up1_fre = nn.Sequential(*modules_up1_fre) + self.up1_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_up2_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up2_fre = nn.Sequential(*modules_up2_fre) + self.up2_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_up3_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up3_fre = nn.Sequential(*modules_up3_fre) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + # define tail module + modules_tail_fre = [ + common.ConvBNReLU2D(self.num_features, 
out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act)] + self.tail_fre = nn.Sequential(*modules_tail_fre) + + def init_T2_spa_branch(self, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down1 = nn.Sequential(*modules_down1) + + + self.down1_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down2 = nn.Sequential(*modules_down2) + + self.down2_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down3 = nn.Sequential(*modules_down3) + self.down3_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.neck = nn.Sequential(*modules_neck) + + self.neck_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_up1 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up1 = nn.Sequential(*modules_up1) + + self.up1_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_up2 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up2 = nn.Sequential(*modules_up2) + self.up2_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + + modules_up3 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up3 = nn.Sequential(*modules_up3) + self.up3_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + # define tail module + modules_tail = [ + common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act)] + + self.tail = nn.Sequential(*modules_tail) + + def init_T2_fre_spa_fusion(self, ): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(self.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, ): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_fre_T1 = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(self.num_features, False, False), + 
common.FreBlock9(self.num_features, ) + ] + + self.down1_fre_T1 = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down2_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down2_fre_T1 = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down3_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down3_fre_T1 = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_neck_fre = [common.FreBlock9(self.num_features, ) + ] + self.neck_fre_T1 = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + def init_T1_spa_branch(self, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_T1 = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = nn.Sequential(*modules_down1) + + + self.down1_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = nn.Sequential(*modules_down2) + + self.down2_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down3_T1 = nn.Sequential(*modules_down3) + self.down3_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.neck_T1 = nn.Sequential(*modules_neck) + + self.neck_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + + def init_modality_fre_fusion(self, ): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, ): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + def forward(self, main, aux, t): + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = 
self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo) # 64 + up2_fre_mo = self.up2_fre_mo(up2_fre) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse) + down1_fuse_mo = self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.up3(up2_fuse_mo) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + + res = self.tail(up3_fuse_mo) + + 
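+        # Editorial note (hedged): both heads act as residual predictors.
+        # The spatial-branch output `res` and the frequency-branch output `res_fre`
+        # are each added back onto the undersampled T2 input `main` before being
+        # returned, so the network only learns a correction on top of its input.
+        # The timestep argument `t` is accepted by this forward() but is not used
+        # anywhere in the body shown here.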
return res + main, res_fre + main + + + + + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/__init__.py b/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/common_freq.py b/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6c78467db1391f6475d069dafda7f295b97ae1 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/common_freq.py @@ -0,0 +1,391 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + + +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
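+            # Editorial note (hedged): for a power-of-two scale this loop stacks
+            # log2(scale) stages, each an invPixelShuffle(2) (space-to-depth:
+            # H,W -> H/2,W/2 with channels expanded to 4*n_feats) followed by a
+            # Conv2d mapping 4*n_feats back to n_feats and a PReLU, so the overall
+            # spatial resolution drops by `scale`. For a scale that is not a power
+            # of two the condition above is false and the Sequential stays empty.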
+ for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs): + + out = self.layers(inputs) + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features): + super(ResBlock, self).__init__() + self.layers = nn.Sequential( + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1), + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + ) + + def forward(self, inputs): + return F.relu(self.layers(inputs) + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [ + ResBlock(n_feat) for _ in range(n_resblocks)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, 
norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x): + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class 
FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, 
posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/modules.py b/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/modules.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
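+# Editorial note (hedged sketch): this hand-rolled InstanceNorm normalises each
+# (sample, channel) pair over its H*W positions with a biased variance estimate,
+# then applies per-channel scale/shift when affine=True (initialised to
+# scale ~ N(0, 0.02), shift = 0 in reset_parameters, unlike the PyTorch default
+# of weight = 1, bias = 0). With affine disabled it should track the built-in
+# layer closely; a rough check (illustrative names, not part of the original):
+#
+#     x = torch.randn(2, 8, 16, 16)
+#     ours = InstanceNorm(8, affine=False)(x)
+#     ref = torch.nn.InstanceNorm2d(8, affine=False, eps=1e-5)(x)
+#     assert torch.allclose(ours, ref, atol=1e-5)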
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/mynet.py b/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/mynet.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3734afef3168a620e49722fb3e1a3e64708c42 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/networks_fsm/mynet.py @@ -0,0 +1,376 @@ +import torch +from torch import nn +from . 
import common_freq as common + + +class TwoBranch(nn.Module): + def __init__(self, num_features, act, base_num_every_group, num_channels): + super(TwoBranch, self).__init__() + + self.num_features = num_features + self.act = act + self.num_channels = num_channels + print("num_channels: ", num_channels) + + num_group = 4 + num_every_group = base_num_every_group + + self.init_T2_frq_branch() + self.init_T2_spa_branch( num_every_group) + self.init_T2_fre_spa_fusion() + + self.init_T1_frq_branch() + self.init_T1_spa_branch( num_every_group) + + self.init_modality_fre_fusion() + self.init_modality_spa_fusion() + + + def init_T2_frq_branch(self, ): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_fre = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + + self.down1_fre = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down2_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down2_fre = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_down3_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down3_fre = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_neck_fre = [common.FreBlock9(self.num_features, ) + ] + self.neck_fre = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo = nn.Sequential(common.FreBlock9(self.num_features )) + + modules_up1_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up1_fre = nn.Sequential(*modules_up1_fre) + self.up1_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_up2_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up2_fre = nn.Sequential(*modules_up2_fre) + self.up2_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_up3_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up3_fre = nn.Sequential(*modules_up3_fre) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + # define tail module + modules_tail_fre = [ + common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act)] + self.tail_fre = nn.Sequential(*modules_tail_fre) + + def init_T2_spa_branch(self, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down1 = nn.Sequential(*modules_down1) + + + self.down1_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down2 = nn.Sequential(*modules_down2) + + self.down2_mo = 
nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down3 = nn.Sequential(*modules_down3) + self.down3_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.neck = nn.Sequential(*modules_neck) + + self.neck_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_up1 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up1 = nn.Sequential(*modules_up1) + + self.up1_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_up2 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up2 = nn.Sequential(*modules_up2) + self.up2_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + + modules_up3 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up3 = nn.Sequential(*modules_up3) + self.up3_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + # define tail module + modules_tail = [ + common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act)] + + self.tail = nn.Sequential(*modules_tail) + + def init_T2_fre_spa_fusion(self, ): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(self.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, ): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_fre_T1 = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + + self.down1_fre_T1 = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down2_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down2_fre_T1 = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down3_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down3_fre_T1 = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_neck_fre = [common.FreBlock9(self.num_features, ) + ] + self.neck_fre_T1 = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + def init_T1_spa_branch(self, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, 
out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_T1 = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = nn.Sequential(*modules_down1) + + + self.down1_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = nn.Sequential(*modules_down2) + + self.down2_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down3_T1 = nn.Sequential(*modules_down3) + self.down3_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.neck_T1 = nn.Sequential(*modules_neck) + + self.neck_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + + def init_modality_fre_fusion(self, ): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, ): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + def forward(self, main, aux, t): + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo) # 
64 + up2_fre_mo = self.up2_fre_mo(up2_fre) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse) + down1_fuse_mo = self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.up3(up2_fuse_mo) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + + res = self.tail(up3_fuse_mo) + + return res + main, res_fre + main + +def make_model(): + return TwoBranch() + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/new_twobranch_model.py b/MRI_recon/code/Frequency-Diffusion/networks/new_twobranch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2c99525dd4649886a0b5da769016c87aa5d6245f --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/new_twobranch_model.py @@ -0,0 +1,515 @@ +import math +import torch +import torch.nn as nn + +from .st_branch_model_spa.utils import AMPLoss, PhaLoss + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
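+    A quick illustrative shape check (assumes only that ``torch`` is
+    imported as above; the values are for illustration, not from any run):
+
+        >>> import torch
+        >>> get_timestep_embedding(torch.tensor([0, 1]), 4).shape
+        torch.Size([2, 4])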
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + 
stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class TransformerBlock(nn.Module): + def __init__(self, embed_dim, num_heads, feedforward_dim, dropout=0.1): + super(TransformerBlock, self).__init__() + self.transformer = nn.TransformerEncoderLayer( + d_model=embed_dim, + nhead=num_heads, + dim_feedforward=feedforward_dim, + dropout=dropout, + batch_first=True, + ) + + def forward(self, x): + return self.transformer(x) + +class CrossAttention(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttention, self).__init__() + self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True) + + def forward(self, query, key, value): + attn_output, attn_weights = self.attention(query, key, value) + return attn_output, attn_weights + + + +class FreBlock(nn.Module): + def __init__(self, channels, embed_dim = 256): + super(FreBlock, self).__init__() + + num_heads = 8 + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_conv = nn.Sequential( + nn.Conv2d(channels, channels, 5, 1, 2), + nn.LeakyReLU(0.1, inplace=True) + ) + self.pha_conv = nn.Sequential( + nn.Conv2d(channels, channels, 5, 1, 2), + nn.LeakyReLU(0.1, inplace=True) + ) + + + self.amp_fuse = nn.Sequential( + TransformerBlock(embed_dim, num_heads, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + # TransformerBlock(embed_dim, num_heads, embed_dim), + # nn.ReLU() + ) + + self.pha_fuse = nn.Sequential( + TransformerBlock(embed_dim, num_heads, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + # TransformerBlock(embed_dim, num_heads, embed_dim) + ) + + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + # self.transformer = TransformerBlock(embed_dim, num_heads, feedforward_dim) + self.cross_attention = CrossAttention(embed_dim, num_heads) + self.cross_attention_2 = CrossAttention(embed_dim, num_heads) + + + def forward(self, x, k=None): + _, _, H, W = x.shape + # k shape, msF_component_fuse shape torch.Size([24, 1, 128, 128]) torch.Size([24, 256, 16, 9]) + + # rfft2 输出的形状 (半频谱): (rows, cols//2 + 1) + # half_W = W // 2 + 1 + # down-scale + # k = torch.nn.functional.interpolate(k, size=(H, W), mode='bilinear', + # align_corners=False).cuda() + # k = k[...,:half_W] + + + fpre = self.fpre(x) + msF = torch.fft. 
rfft2(fpre + 1e-8, norm='ortho') + msF = torch.fft.fftshift(msF, dim=[2, 3]) + + msF_ori= msF.clone() # * k + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + + + msF_amp = self.amp_conv(msF_amp) + msF_pha = self.pha_conv(msF_pha) + + batch_size, channels, height, width = msF_amp.shape + msF_amp_flatten = msF_amp.view(batch_size, channels, -1).permute(0, 2, 1) # (batch_size, H*W, channels) + msF_pha_flatten = msF_pha.view(batch_size, channels, -1).permute(0, 2, 1) # (batch_size, H*W, channels) + # print("msF_amp_flatten shape", msF_amp_flatten.shape) + + # channels = msF_amp.shape[1] + msF_amp_flatten, _ = self.cross_attention( msF_amp_flatten, msF_pha_flatten, msF_pha_flatten) + msF_pha_flatten, _ = self.cross_attention_2(msF_pha_flatten, msF_amp_flatten, msF_amp_flatten) + + amplitude_features = self.amp_fuse(msF_amp_flatten) # + msF_component + angle_features = self.pha_fuse(msF_pha_flatten) # + msF_component + + # cross attention + amp_fuse = amplitude_features.permute(0, 2, 1).view(batch_size, channels, height, width) + pha_fuse = angle_features.permute(0, 2, 1).view(batch_size, channels, height, width) + + amp_fuse = nn.ReLU()(amp_fuse) + real = amp_fuse * torch.cos(pha_fuse) + 1e-8 + imag = amp_fuse * torch.sin(pha_fuse) + 1e-8 + + out = torch.complex(real, imag) + 1e-8 + out = out + msF_ori # * (1 - k) + + out = torch.fft.ifftshift(out, dim=[2, 3]) + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='ortho')) + out = self.post(out) + + + + out = torch.nan_to_num(out, nan=1e-5, posinf=1, neginf=-1) + + return out + + +class Branch(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + + + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels * 2, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,) + ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + fre = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + fre.append(FreBlock(channels=block_in)) + down = nn.Module() + down.block = block + down.attn = attn + down.fre = fre + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + self.mid_fre = FreBlock(channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + fre = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in 
range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + fre.append(FreBlock(channels=block_in)) + up = nn.Module() + up.block = block + up.attn = attn + up.fre = fre + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + + + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + self.spatial = Branch(ch=ch, out_ch=out_ch, ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + dropout=dropout, resamp_with_conv=resamp_with_conv, + in_channels=in_channels, resolution=resolution) + + self.amploss = AMPLoss() # .to(self.device, non_blocking=True) + self.phaloss = PhaLoss() # .to(self.device, non_blocking=True) + + self.use_front_fre = False + self.use_after_fre = False + print("=== use front fre", self.use_front_fre) # NAN + print("=== use after fre", self.use_after_fre) # use_after_fre_ BUG NAN + + def forward(self, x, aux, k, t): + assert x.shape[2] == x.shape[3] == self.resolution + + # k = k.to(x.device) + + # timestep embedding + temp = None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + + x_in = torch.cat((x, aux), dim=1) + + # spatial downsampling + hs = [self.spatial.conv_in(x_in)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.spatial.down[i_level].block[i_block](hs[-1], temb) + if len(self.spatial.down[i_level].attn) > 0: + if self.use_front_fre: + h = self.spatial.down[i_level].fre[i_block](h, k) + h = self.spatial.down[i_level].attn[i_block](h) + + if self.use_after_fre: + h = self.spatial.down[i_level].fre[i_block](h, k) + h + + hs.append(h) + + if i_level != self.num_resolutions-1: + hs.append(self.spatial.down[i_level].downsample(hs[-1])) + + # spatial middle + h = hs[-1] + h = self.spatial.mid.block_1(h, temb) + h = self.spatial.mid.attn_1(h) + h = self.spatial.mid.block_2(h, temb) + + # if self.use_front_fre or self.use_after_fre: + # h = self.spatial.mid_fre(h, k) # + h # NAN?? 
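+        # Note: the frequency path is optional on both halves of the U-Net.
+        # `use_front_fre` applies the FreBlock before each attention block and
+        # `use_after_fre` adds a FreBlock residual after it; both flags are
+        # hard-coded to False in __init__, apparently because enabling them
+        # produced NaNs (see the "# NAN" notes there). The commented-out
+        # `mid_fre` call above is disabled for the same reason.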
+ + # spatial upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.spatial.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.spatial.up[i_level].attn) > 0: + if self.use_front_fre: + h = self.spatial.up[i_level].fre[i_block](h, k) + h = self.spatial.up[i_level].attn[i_block](h) + if self.use_after_fre: + h = self.spatial.up[i_level].fre[i_block](h, k) + h + + # TODO residual + # h += hs.pop() + + if i_level != 0: + h = self.spatial.up[i_level].upsample(h) + + # spatial end + h = self.spatial.norm_out(h) + h = nonlinearity(h) + h = self.spatial.conv_out(h) + + return h diff --git a/MRI_recon/code/Frequency-Diffusion/networks/st_branch_model/common_freq.py b/MRI_recon/code/Frequency-Diffusion/networks/st_branch_model/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..f209b9da0894b884345487dde2ebe9344ca131f3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/st_branch_model/common_freq.py @@ -0,0 +1,456 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange +from torch.fft import * + +def frequency_transform(x_input, pixel_range='-1_1', to_frequency=True): + if to_frequency: + if pixel_range == '0_1': + pass + + elif pixel_range == '-1_1': + # x_start (-1, 1) --> (0, 1) + x_start = (x_input + 1) / 2 + + elif pixel_range == 'complex': + x_start = torch.complex(x_input[:, :1, ...], x_input[:, 1:, ...]) + + else: + raise ValueError(f"Unknown pixel range {pixel_range}.") + + fft = fftshift(fft2(x_input)) + return fft + + else: + x_ksu = ifft2(ifftshift(x_input)) + + if pixel_range == '0_1': + x_ksu = torch.abs(x_ksu) + + elif pixel_range == '-1_1': + x_ksu = torch.abs(x_ksu) + # x_ksu (0, 1) --> (-1, 1) + x_ksu = x_ksu * 2 - 1 + + elif pixel_range == 'complex': + x_ksu = torch.concat((x_ksu.real, x_ksu.imag), dim=1) + else: + raise ValueError(f"Unknown pixel range {pixel_range}.") + + return x_ksu + + +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
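+        # Power-of-two check via the usual bit trick: a positive integer has
+        # exactly one set bit iff it is a power of two, so scale & (scale - 1)
+        # is 0 for 2, 4, 8, 16 (e.g. 8 & 7 == 0) but not for 6 (6 & 5 == 4).
+        # Only such scales get the log2(scale) stack of x2 stages below.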
+ for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None, temb_ch=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + + self.temb_proj = None + if temb_ch != None: + self.temb_proj = torch.nn.Linear(temb_ch, + out_channels) + + + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs, temp=None): + + out = self.layers(inputs) + + if self.temb_proj != None and temp !=None: + out = out + self.temb_proj(nonlinearity(temp))[:, :, None, None] + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features, temb_ch): + super(ResBlock, self).__init__() + self.layers_1 = \ + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, + act='ReLU', padding=1, temb_ch=temb_ch) + self.layers_2 = \ + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, + padding=1, temb_ch=temb_ch) + + + + def forward(self, inputs, temp=None): + out = self.layers_1(inputs, temp) + out = self.layers_2(out, temp) + + return F.relu(out + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, 
out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks, temb_ch=None): + super(ResidualGroup, self).__init__() + self.body = nn.ModuleList([ + ResBlock(n_feat, temb_ch) for _ in range(n_resblocks)]) + + + self.end = ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, + norm=norm, temb_ch=temb_ch) + + + self.re_scale = Scale(1) + + def forward(self, x, temp): + # res = self.body(x) + res = x + for block in self.body: + res = block(res, temp) + res = self.end(res, temp) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x, temp=None): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + 
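+        # Cross-attention in both directions (frequency attending to spatial
+        # and vice versa), followed by a Sigmoid-gated `fuse` conv whose output
+        # is chunked into per-branch gates in forward(). Note that forward() as
+        # written calls `fre_att` for both directions, so `spa_att` below is
+        # defined but never used.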
self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = 
torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/st_branch_model/model.py b/MRI_recon/code/Frequency-Diffusion/networks/st_branch_model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1ade669aa7079d8bb9a8c65e6190ae92a1bfd639 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/st_branch_model/model.py @@ -0,0 +1,786 @@ +import torch, math +from torch import nn +from . import common_freq as common +import torch.nn.functional as F + +from .utils import adopt_weight, hinge_d_loss, vanilla_d_loss +from metrics.lpips import LPIPS +# from vq_gan_3d.model.codebook import Codebook +import numpy as np + + +def silu(x): + return x * torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +class DownBlock(nn.Module): + def __init__(self, num_features, act=True, norm=True, fre_layer=False, + kernel_size=3, reduction = 4, num_every_group=1, temb_ch=None, + spa_norm=None, spa_act=None): + super(DownBlock, self).__init__() + + self.downsample = common.DownSample(num_features, act, norm) + + self.fre_layer = None + self.spa_layer = None + if fre_layer: + self.fre_layer = common.FreBlock9(num_features) + else: + self.spa_layer = common.ResidualGroup( + num_features, kernel_size, reduction, act=spa_act, + n_resblocks=num_every_group, norm=spa_norm, temb_ch=temb_ch) + + def forward(self, x, temp=None): + out = self.downsample(x) + if self.fre_layer is not None: + out = self.fre_layer(out, temp) + else: + out = self.spa_layer(out, temp) + return out + + + +class UpBlock(nn.Module): + def __init__(self, scale, num_features, act=True, norm=True, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=1, temb_ch=None, + spa_norm=None, spa_act=None + ): + super(UpBlock, self).__init__() + + self.upsample = common.UpSampler(scale, num_features) + self.fre_layer = None + self.spa_layer = None + if fre_layer: + self.fre_layer = common.FreBlock9(num_features) + else: + self.spa_layer = common.ResidualGroup( + num_features, kernel_size, reduction, act=spa_act, + n_resblocks=num_every_group, norm=spa_norm, temb_ch=temb_ch) + + + def forward(self, x, temp=None): + out = self.upsample(x) + if self.fre_layer is not None: + out = self.fre_layer(out, temp) + else: + out = self.spa_layer(out, temp) + return out + + + +class DuplicateBlock(nn.Module): + def __init__(self, block, num_of_block, **kwargs): + super(DuplicateBlock, self).__init__() + + self.blocks = nn.ModuleList([block(**kwargs) for _ in range(num_of_block)]) + + def forward(self, x, temp=None): + for block in self.blocks: + x = block(x, temp) + return x + + +class ModelBackbone(nn.Module): + def __init__(self, num_features, act, base_num_every_group, num_channels, temb_ch): + super(ModelBackbone, self).__init__() + + self.num_features = num_features + self.act = act + self.num_channels = num_channels + self.temb_ch = temb_ch + + # 
self.args = args + num_every_group = base_num_every_group + + self.init_T2_frq_branch() + self.init_T2_spa_branch(num_every_group) + self.init_T2_fre_spa_fusion() + + self.init_T1_frq_branch() + self.init_T1_spa_branch(num_every_group) + + self.init_modality_fre_fusion() + self.init_modality_spa_fusion() + + def init_T2_frq_branch(self): + ### T2frequency branch + self.head_fre = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + self.down1_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down1_fre_mo = common.FreBlock9(self.num_features) + + self.down2_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down2_fre_mo = common.FreBlock9(self.num_features) + + self.down3_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down3_fre_mo = common.FreBlock9(self.num_features) + + self.neck_fre = common.FreBlock9(self.num_features) + + self.neck_fre_mo = common.FreBlock9(self.num_features) + + ### T2frequency branch + self.head_fre = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + self.down1_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down1_fre_mo = common.FreBlock9(self.num_features) + + self.down2_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down2_fre_mo = common.FreBlock9(self.num_features) + + self.down3_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down3_fre_mo = common.FreBlock9(self.num_features) + + self.neck_fre = common.FreBlock9(self.num_features) + + self.neck_fre_mo = common.FreBlock9(self.num_features) + + self.up1_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up1_fre_mo = common.FreBlock9(self.num_features) + + + self.up2_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up2_fre_mo = common.FreBlock9(self.num_features) + + + self.up3_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + # define tail module + self.tail_fre = common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act, temb_ch=self.temb_ch) + + + + self.up1_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up1_fre_mo = common.FreBlock9(self.num_features) + + self.up2_fre = UpBlock(2, self.num_features, fre_layer=True) + + self.up2_fre_mo = common.FreBlock9(self.num_features) + + self.up3_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up3_fre_mo = common.FreBlock9(self.num_features) + + # define tail module + self.tail_fre = common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act, temb_ch=self.temb_ch) + + + + def init_T2_spa_branch(self, num_every_group): + ### spatial branch + + self.head = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + + self.down1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, + num_every_group=num_every_group, temb_ch=None, spa_norm=None, spa_act=self.act) + + + self.down1_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.down2 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, + num_every_group=num_every_group, 
temb_ch=None, spa_norm=None, spa_act=self.act) + + self.down2_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.down3 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, + num_every_group=num_every_group, temb_ch=None, spa_norm=None, spa_act=self.act) + + self.down3_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.neck = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.neck_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.up1 = UpBlock(2, self.num_features, act=None, norm=None, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + + self.up1_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.up2 = UpBlock(2, self.num_features, act=None, norm=None, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + self.up2_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + + self.up3 = UpBlock(2, self.num_features, act=None, norm=None, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch) + + self.up3_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + # define tail module + self.tail = common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act, temb_ch=self.temb_ch) + + + + + + def init_T1_frq_branch(self): + ### T2frequency branch + self.head_fre_T1 = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + + + self.down1_fre_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=True, temb_ch=self.temb_ch) + + self.down1_fre_mo_T1 = common.FreBlock9(self.num_features) + + + self.down2_fre_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=True, temb_ch=self.temb_ch) + + self.down2_fre_mo_T1 = common.FreBlock9(self.num_features) + + + self.down3_fre_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=True, temb_ch=self.temb_ch) + self.down3_fre_mo_T1 = common.FreBlock9(self.num_features) + + + self.neck_fre_T1 = common.FreBlock9(self.num_features) + self.neck_fre_mo_T1 = common.FreBlock9(self.num_features) + + + + + def init_T1_spa_branch(self, num_every_group): + ### spatial branch + + self.head_T1 = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + + + self.down1_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + + self.down1_mo_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, + n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + + self.down2_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + 
kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + + self.down2_mo_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, + norm=None, temb_ch=self.temb_ch) + + + self.down3_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + + self.down3_mo_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, + norm=None, temb_ch=self.temb_ch) + + + self.neck_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.neck_mo_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, + norm=None, temb_ch=self.temb_ch) + + + def init_T2_fre_spa_fusion(self): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(self.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_modality_fre_fusion(self): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + + # def init_T2_fre_spa_fusion(self): + # ### T2 frq & spa fusion part + # self.conv_fuse = DuplicateBlock(common.FuseBlock7, 14, + # channels=self.num_features) + # + # def init_modality_fre_fusion(self): + # self.conv_fuse_fre = DuplicateBlock(common.FuseBlock6, 5, channels=self.num_features) + # + # def init_modality_spa_fusion(self): + # self.conv_fuse_spa = DuplicateBlock(common.FuseBlock6, 5, channels=self.num_features) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
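+    The zero pad below handles odd embedding sizes; an illustrative check
+    (assumes only that ``torch`` is imported as above):
+
+        >>> import torch
+        >>> get_timestep_embedding(torch.tensor([1]), 5).shape
+        torch.Size([1, 5])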
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + + +class TwoBranchModel(nn.Module): + def __init__(self, num_features, act, base_num_every_group, num_channels): + super(TwoBranchModel, self).__init__() + + num_group = 4 + self.use_fre_mix = False + self.ch = num_channels + self.temb_ch = num_channels * 4 + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(num_channels, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ]) + + + self.model = ModelBackbone(num_features, act, + base_num_every_group, num_channels, + temb_ch=self.temb_ch) + + + + + def forward(self, main, aux, t): + + # timestep embedding + temb =None + + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + # + + #### T1 fre encoder # T1 + temb_fre = None + t1_fre = self.model.head_fre_T1(aux, temb_fre) # 128 + + down1_fre_t1 = self.model.down1_fre_T1(t1_fre, temb_fre)# 64 + down1_fre_mo_t1 = self.model.down1_fre_mo_T1(down1_fre_t1, temb_fre) + + down2_fre_t1 = self.model.down2_fre_T1(down1_fre_mo_t1, temb_fre) # 32 + down2_fre_mo_t1 = self.model.down2_fre_mo_T1(down2_fre_t1, temb_fre) + + down3_fre_t1 = self.model.down3_fre_T1(down2_fre_mo_t1, temb_fre) # 16 + down3_fre_mo_t1 = self.model.down3_fre_mo_T1(down3_fre_t1, temb_fre) + + neck_fre_t1 = self.model.neck_fre_T1(down3_fre_mo_t1, temb_fre) # 16 + neck_fre_mo_t1 = self.model.neck_fre_mo_T1(neck_fre_t1, temb_fre) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.model.head_fre(main, temb) # 128 + x_fre_fuse = self.model.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.model.down1_fre(x_fre_fuse, temb)# 64 + down1_fre_mo = self.model.down1_fre_mo(down1_fre, temb) + down1_fre_mo_fuse = self.model.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.model.down2_fre(down1_fre_mo_fuse, temb) # 32 + down2_fre_mo = self.model.down2_fre_mo(down2_fre, temb) + down2_fre_mo_fuse = self.model.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.model.down3_fre(down2_fre_mo_fuse, temb) # 16 + down3_fre_mo = self.model.down3_fre_mo(down3_fre, temb) + down3_fre_mo_fuse = self.model.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.model.neck_fre(down3_fre_mo_fuse, temb) # 16 + neck_fre_mo = self.model.neck_fre_mo(neck_fre, temb) + neck_fre_mo_fuse = self.model.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.model.up1_fre(neck_fre_mo, temb) # 32 + up1_fre_mo = self.model.up1_fre_mo(up1_fre, temb) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.model.up2_fre(up1_fre_mo, temb) # 64 + up2_fre_mo = self.model.up2_fre_mo(up2_fre, temb) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.model.up3_fre(up2_fre_mo, temb) # 128 + up3_fre_mo = self.model.up3_fre_mo(up3_fre, temb) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.model.tail_fre(up3_fre_mo, temb) + + #### T1 spa encoder + x_t1 = 
self.model.head_T1(aux, temb) # 128 + + down1_t1 = self.model.down1_T1(x_t1, temb) # 64 + down1_mo_t1 = self.model.down1_mo_T1(down1_t1, temb) + + down2_t1 = self.model.down2_T1(down1_mo_t1, temb) # 32 + down2_mo_t1 = self.model.down2_mo_T1(down2_t1, temb) # 32 + + down3_t1 = self.model.down3_T1(down2_mo_t1, temb) # 16 + down3_mo_t1 = self.model.down3_mo_T1(down3_t1, temb) # 16 + + neck_t1 = self.model.neck_T1(down3_mo_t1, temb) # 16 + neck_mo_t1 = self.model.neck_mo_T1(neck_t1, temb) + + #### T2 spa encoder and fusion + x = self.model.head(main, temb) # 128 + + x_fuse = self.model.conv_fuse_spa[0](x_t1, x) + down1 = self.model.down1(x_fuse, temb) # 64 + down1_fuse = self.model.conv_fuse[0](down1_fre, down1) + down1_mo = self.model.down1_mo(down1_fuse, temb) + down1_fuse_mo = self.model.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.model.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.model.down2(down1_fuse_mo_fuse, temb) # 32 + down2_fuse = self.model.conv_fuse[2](down2_fre, down2) + down2_mo = self.model.down2_mo(down2_fuse, temb) # 32 + down2_fuse_mo = self.model.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.model.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.model.down3(down2_fuse_mo_fuse, temb) # 16 + down3_fuse = self.model.conv_fuse[4](down3_fre, down3) + down3_mo = self.model.down3_mo(down3_fuse, temb) # 16 + down3_fuse_mo = self.model.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.model.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.model.neck(down3_fuse_mo_fuse, temb) # 16 + neck_fuse = self.model.conv_fuse[6](neck_fre, neck) + neck_mo = self.model.neck_mo(neck_fuse, temb) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.model.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.model.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.model.up1(neck_fuse_mo_fuse, temb) # 32 + up1_fuse = self.model.conv_fuse[8](up1_fre, up1) + up1_mo = self.model.up1_mo(up1_fuse, temb) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.model.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.model.up2(up1_fuse_mo, temb) # 64 + up2_fuse = self.model.conv_fuse[10](up2_fre, up2) + up2_mo = self.model.up2_mo(up2_fuse, temb) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.model.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.model.up3(up2_fuse_mo, temb) # 128 + + up3_fuse = self.model.conv_fuse[12](up3_fre, up3) + up3_mo = self.model.up3_mo(up3_fuse, temb) + + up3_mo = up3_mo + x + up3_fuse_mo = self.model.conv_fuse[13](up3_fre_mo, up3_mo) + + res = self.model.tail(up3_fuse_mo, temb) + + # if self.use_res: + # res = res + # res_fre = res_fre + + return res + main, res_fre + main + + + def training_step(self, batch, batch_idx, optimizer_idx): + self.model.train() + + x = batch['image'] + aux = batch['aux'] + + x = x.squeeze(1) + aux = aux.squeeze(1) + + # print(x.shape) + # print(aux.shape) + + # torch.Size([8, 96, 96, 1]) + # torch.Size([16, 1, 96, 96, 96]) + + x = x.permute(0, -1, -3, -2)#.detach() # [B, C, H, W] + aux = aux.permute(0, -1, -3, -2)#.detach() # [B, C, H, W] + + out = self.forward(x, aux) + recon_out = out['recon_out'] + recon_fre = out['recon_fre'] + + if optimizer_idx == 0: + fft_weight = 0.01 + use_dis = False + recon_out_loss = self.get_recon_loss(recon_out, x, tag="recon_out", use_dis=use_dis) + recon_fre_loss = self.get_recon_loss(recon_fre, x, tag="recon_fre", use_dis=use_dis) + # amp = self.amploss(recon_fre, x) + # pha = 
self.phaloss(recon_fre, x) + loss = recon_out_loss + recon_fre_loss #+ fft_weight * ( amp + pha ) + + elif optimizer_idx == 1: + loss = self.get_dis_loss(recon_out, x, tag="dis") + + # print("loss = ", loss) + + return loss + + + def get_dis_loss(self, recon, target, tag="dis"): + B, C, H, W = recon.shape + # Selects one random 2D image from each 3D Image + + logits_image_real, _ = self.image_discriminator(target.detach()) + logits_image_fake, _ = self.image_discriminator(recon.detach()) + + print("logits_image_real = ", torch.mean(logits_image_real)) + print("logits_image_fake = ", torch.mean(logits_image_fake)) + + d_image_loss = self.disc_loss(logits_image_real , logits_image_fake) + # print("d_image_loss = ", d_image_loss) + + # d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.discriminator_iter_start) + discloss = disc_factor * (self.image_gan_weight * d_image_loss ) + + self.log(f"train/{tag}/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/disc_loss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + + def get_recon_loss(self, recon, target, tag="recon_out", use_dis=True): + recon_loss = F.l1_loss(recon, target) * self.l1_weight + # recon_loss = ((recon - target)**2).mean() * self.l1_weight + + # Perceptual loss + perceptual_loss = 0 + aeloss = 0 + image_gan_feat_loss = 0 + g_image_loss = 0 + + # Slice it into T, H, W random slices + if self.perceptual_weight > 0: + B, C, H, W = recon.shape + # Selects one random 2D image from each 3D Image + + perceptual_loss = self.perceptual_model(recon, target).mean() * self.perceptual_weight + recon_loss += perceptual_loss + + + if use_dis: + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator(recon) + # logits_video_fake, pred_video_fake = self.video_discriminator(recon) + g_image_loss = -torch.mean(logits_image_fake) + # g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight * g_image_loss + + disc_factor = adopt_weight( + self.global_step, threshold=self.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( recon) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + + gan_feat_loss = disc_factor * self.gan_feat_weight * (image_gan_feat_loss) + recon_loss += gan_feat_loss + aeloss + + self.log(f"train/{tag}/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log(f"train/{tag}/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/recon_loss", 
recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + return recon_loss + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw - 1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + diff --git a/MRI_recon/code/Frequency-Diffusion/networks/st_branch_model/utils.py b/MRI_recon/code/Frequency-Diffusion/networks/st_branch_model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c92a018e6b6409617ff7982339736b9db36c7fa --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/networks/st_branch_model/utils.py @@ -0,0 +1,220 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np +import skvideo.io + +import sys +import pdb as pdb_original +import SimpleITK as sitk +import logging +from torch import nn +import torch.nn.functional as F + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src_tf dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + 
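+    """Return getattr(args, attr_name) if the attribute exists on args, otherwise return the given default."""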
if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + + +class PhaLoss(nn.Module): + def __init__(self, epsilon=1e-8, norm='ortho'): + super(PhaLoss, self).__init__() + self.cri = nn.L1Loss() + self.epsilon = epsilon # To prevent undefined phase for zero magnitudes + self.norm = norm # Normalization for FFT + + def forward(self, x, y): + # Validate inputs + if not torch.isfinite(x).all() or not torch.isfinite(y).all(): + raise ValueError("Input contains NaN or Inf values") + + # Perform FFT + x_fft = torch.fft.rfft2(x, norm=self.norm) + y_fft = torch.fft.rfft2(y, norm=self.norm) + + # Compute phase + x_phase = torch.angle(x_fft) + y_phase = torch.angle(y_fft) + + # Compute L1 loss between phases + return self.cri(x_phase, y_phase) diff --git a/MRI_recon/code/Frequency-Diffusion/reproduce.sh b/MRI_recon/code/Frequency-Diffusion/reproduce.sh new file mode 100644 index 0000000000000000000000000000000000000000..bb212ecc711b35ee2a7a9b0e103e541a8ef9b618 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/reproduce.sh @@ -0,0 +1,35 @@ +conda create -n diffmri python=3.10 -y + +conda activate diffmri + + +pip install -r ./requirements.txt + + +# pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 \ + # --extra-index-url https://download.pytorch.org/whl/cu113 + + + +pip install torch==2.4.0+cu124 torchvision==0.19.0+cu124 torchaudio==2.4.0+cu124 \ + --extra-index-url https://download.pytorch.org/whl/cu124 + + + +pip install comet_ml torchgeometry albumentations +pip install --upgrade matplotlib +pip install --upgrade scikit-learn pytorch_msssim + + +pip install --upgrade pandas scipy scikit-image scikit-video scipy pytorch_lightning einops SimpleITK +pip uninstall h5py +pip install h5py fastmri torchmetrics +# torchmetric +pip install wandb==0.19 +pip install tensorboardX timm ml_collections + +# DR: Retinal Fundus Imaging +# OCT + Retinal Fundus Imaging + + + diff --git a/MRI_recon/code/Frequency-Diffusion/requirements.txt b/MRI_recon/code/Frequency-Diffusion/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..955346e0e023adaec83beaa0ac38ab8cbb7be7ac --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/requirements.txt @@ -0,0 +1,104 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 +async-timeout==4.0.2 +attrs==21.4.0 +autopep8==1.6.0 +cachetools==5.2.0 
+certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.9.2 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 +rotary-embedding-torch==0.1.5 +rsa==4.8 +scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 +smmap==5.0.0 +tensorboard +tensorboard-data-server +tensorboard-plugin-wit +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 +zipp==3.8.0 +tensorboardX==2.4.1 +protobuf==3.20.1 +sk-video +torchstat +timm +elasticdeform +opencv-python +monai[nibabel] +monai==0.9.0 +nibabel +ml-collections +glob2 diff --git a/MRI_recon/code/Frequency-Diffusion/test.py b/MRI_recon/code/Frequency-Diffusion/test.py new file mode 100644 index 0000000000000000000000000000000000000000..96bda5d729e48687f7a1c4ffe111363b3af3da9e --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/test.py @@ -0,0 +1,182 @@ +from diffusion_pytorch import GaussianDiffusion, Trainer, Model +from Fid import calculate_fid_given_samples +import torchvision +import os +import errno +import shutil +import argparse + + +def create_folder(path): + try: + os.mkdir(path) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + pass + + +def del_folder(path): + try: + shutil.rmtree(path) + except OSError as exc: + pass + + +create = 0 + +if create: + trainset = torchvision.datasets.CIFAR10( + root='./data', train=False, download=True) + root = './root_cifar10_test/' + del_folder(root) + create_folder(root) + + for i in range(10): + lable_root = root + str(i) + '/' + create_folder(lable_root) + + for idx in range(len(trainset)): + img, label = trainset[idx] + print(idx) + img.save(root + str(label) + '/' + str(idx) + '.png') + + +parser = argparse.ArgumentParser() +parser.add_argument('--time_steps', default=50, type=int) +parser.add_argument('--sample_steps', default=None, type=int) +parser.add_argument('--kernel_std', default=0.1, type=float) +parser.add_argument('--save_folder', default='progression_cifar', type=str) +parser.add_argument('--load_path', default='/cmlscratch/eborgnia/cold_diffusion/paper_defading_random_1/model.pt', type=str) +parser.add_argument('--data_path', default='./root_cifar10_test/', type=str) +parser.add_argument('--test_type', default='test_paper_showing_diffusion_images_diff', type=str) +parser.add_argument('--fade_routine', default='Random_Incremental', type=str) +parser.add_argument('--sampling_routine', default='x0_step_down', type=str) 
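+# NOTE: later code in this file references args.image_size and args.loss_type, which the
+# parser does not define as committed; the two arguments below are assumed defaults
+# (32 and 'l1', matching values that appear further down) added so the script can run.
+# The bare names model_name and diffusion_type, and the TwoBranchNewModel / TwoBranchModel
+# classes used in the 'twounet'/'twobranch' branches, also still need to be defined or imported.
+parser.add_argument('--image_size', default=32, type=int)
+parser.add_argument('--loss_type', default='l1', type=str)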
+parser.add_argument('--remove_time_embed', action="store_true") +parser.add_argument('--discrete', action="store_true") +parser.add_argument('--residual', action="store_true") + +args = parser.parse_args() +print(args) + +img_path=None +if 'train' in args.test_type: + img_path = args.data_path +elif 'test' in args.test_type: + img_path = args.data_path + +print("Img Path is ", img_path) + + + +image_channels = 1 + +if model_name == "unet": + model = Model(resolution=args.image_size, + in_channels=1, + out_ch=1, + ch=128, + ch_mult=(1, 2, 2, 2), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.1).cuda() + +elif model_name == "twounet": + + model = TwoBranchNewModel(resolution=args.image_size, + in_channels=1, + out_ch=1, + ch=128, + ch_mult=(1, 2, 2, 2), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.1).cuda() + +elif model_name == "twobranch": + downsample = [4, 4, 4] + disc_channels = 64 + disc_layers = 3 + discriminator_iter_start = 10000 + disc_loss_type = "hinge" + image_gan_weight = 1.0 + video_gan_weight = 1.0 + l1_weight = 4.0 + gan_feat_weight = 4.0 + perceptual_weight = 4.0 + i3d_feat = False + restart_thres = 1.0 + no_random_restart = False + norm_type = "group" + padding_type = "replicate" + num_groups = 32 + + base_num_every_group = 2 + num_features = 64 + act = "PReLU" + num_channels = 1 + + model = TwoBranchModel( + image_channels, + disc_channels, disc_layers, disc_loss_type, + gan_feat_weight, image_gan_weight, + discriminator_iter_start, + perceptual_weight, l1_weight, + num_features, act, base_num_every_group, num_channels + ).cuda() + + +diffusion = GaussianDiffusion( + diffusion_type, + model, + image_size=args.image_size, # Used to be 32 + channels=image_channels, + device_of_kernel='cuda', + timesteps=args.time_steps, + loss_type=args.loss_type, #$'l1', + kernel_std=args.kernel_std, + fade_routine=args.fade_routine, + sampling_routine=args.sampling_routine, + discrete=args.discrete +).cuda() + + +trainer = Trainer( + diffusion, + img_path, + image_size = 32, + train_batch_size = 32, + train_lr = 2e-5, + train_num_steps = 700000, # total training steps + gradient_accumulate_every = 2, # gradient accumulation steps + ema_decay = 0.995, # exponential moving average decay + fp16 = False, # turn on mixed precision training with apex + results_folder = args.save_folder, + load_path = args.load_path +) + + + + +if args.test_type == 'train_data': + trainer.test_from_data('train', s_times=args.sample_steps) + +elif args.test_type == 'test_data': + trainer.test_from_data('test', s_times=args.sample_steps) + +elif args.test_type == 'mixup_train_data': + trainer.test_with_mixup('train') + +elif args.test_type == 'mixup_test_data': + trainer.test_with_mixup('test') + +elif args.test_type == 'test_random': + trainer.test_from_random('random') + +elif args.test_type == 'test_fid_distance_decrease_from_manifold': + trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) + +elif args.test_type == 'test_paper_invert_section_images': + trainer.paper_invert_section_images() + +elif args.test_type == 'test_paper_showing_diffusion_images_diff': + trainer.paper_showing_diffusion_images() diff --git a/MRI_recon/code/Frequency-Diffusion/visualization/eroor_map_.py b/MRI_recon/code/Frequency-Diffusion/visualization/eroor_map_.py new file mode 100644 index 0000000000000000000000000000000000000000..91fde841b6020435360c0e773e88934c360d7b72 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/visualization/eroor_map_.py @@ -0,0 +1,49 @@ 
+""" +读取图像的ground truth和每个round重建结果, 并且绘制error map. +""" +import os +import numpy as np +from PIL import Image +from skimage import io +from matplotlib import pyplot as plt + + +def normalize_image(image): + # return (image - image.min())/(image.max() - image.min()) + image = image[40:200, 55:215] + # image = image[80:160, 95:175] + print("image shape:", image.shape) + return image/255.0 + + +def viz_diff_img(image, test_outputdir, image_name): + print("image range:", image.max(), image.min()) + plt.imshow(image, cmap='jet') + plt.savefig(os.path.join(test_outputdir, f'{image_name}'), + bbox_inches='tight') + + +root_dir = "/data/xiaohan/BRATS_dataset/image_100patients_unimodal/" +image_name = "BraTS20_Training_042_60_t1" + +dst_dir = "./recon_image_visualization" + + +img_gt = normalize_image(np.array(Image.open(root_dir + image_name + ".png"))) +img_in = normalize_image(np.array(Image.open(root_dir + image_name + "_10dB.png"))) + +img_round1 = normalize_image(np.array(Image.open(root_dir + image_name + "_10dB_krecon_round1.png"))) + +print(img_gt.max(), img_gt.min()) +print(img_in.max(), img_in.min()) +print(img_round1.max(), img_round1.min()) + +io.imsave(os.path.join(dst_dir, image_name + ".png"), img_gt) +io.imsave(os.path.join(dst_dir, image_name + "_10dB.png"), img_in) +io.imsave(os.path.join(dst_dir, image_name + "_10dB_round1.png"), img_round1) + +viz_diff_img(np.abs(img_gt - img_in)*255, dst_dir, image_name + "_input_error.png") +viz_diff_img(np.abs(img_gt - img_round1)*255, dst_dir, image_name + "_round1_error.png") + +print("input error:", np.mean(np.abs(img_gt - img_in))) +print("round1 error:", np.mean(np.abs(img_gt - img_round1))) \ No newline at end of file diff --git a/MRI_recon/code/Frequency-Diffusion/visualization/error_map.py b/MRI_recon/code/Frequency-Diffusion/visualization/error_map.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa3f468d86ff43154659f1d87e768896e0f0dd3 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/visualization/error_map.py @@ -0,0 +1,54 @@ +""" +读取图像的ground truth和每个round重建结果, 并且绘制error map. 
+""" +import os +import numpy as np +from PIL import Image +from skimage import io +from matplotlib import pyplot as plt + + +def normalize_image(image): + # return (image - image.min())/(image.max() - image.min()) + image = image[40:200, 55:215] + # image = image[80:160, 95:175] + print("image shape:", image.shape) + return image/255.0 + # return (image - image.min())/(image.max() - image.min()) + + +def viz_diff_img(image, test_outputdir, image_name): + print("image range:", image.max(), image.min()) + plt.axis('off') + # plt.imshow(image, cmap='jet',vmin=0, vmax=50) + plt.imshow(image, cmap='jet',vmin=0, vmax=30) + # plt.colorbar() + plt.savefig(os.path.join(test_outputdir, f'{image_name}'), bbox_inches='tight',pad_inches = 0) + +# baseline = 'UNet_4X' +# baseline_list = ['DCAMSR_4X', 'MCCA_4X', 'MINet_4X', 'MTrans_4X', 'swinir_4X_'] +baseline_list = ['DCAMSR_8X', 'MCCA_8X', 'MINet_8X', 'MTrans_8X', 'swinir_8X_'] +# baseline_list = ['swinir_8X_'] +baseline_list = ['our'] +for baseline in baseline_list: + # root_dir = f"/data/qic99/recon_code/recon_2M/BRATS_baseline/model/{baseline}/result_case/" + root_dir = '/data/qic99/recon_code/recon_2M/BRATS_freq_multi_fusion_2_2_4/model/unet_wo_kspace_4X_lr1e-4/result_case/' + # root_dir = '/data/qic99/recon_code/recon_2M/BRATS_freq_multi_fusion_2_2_4/model/unet_wo_kspace_8X_lr1e-4/result_case/' + image_name = "301_t2" + + dst_dir = "./error_map_8X" + # dst_dir = "./error_map_4X" + os.makedirs(dst_dir, exist_ok=True) + img_gt = normalize_image(np.array(Image.open(root_dir + image_name + ".png"))) + img_in = normalize_image(np.array(Image.open(root_dir + image_name + "_out.png"))) + img_lq = normalize_image(np.array(Image.open(root_dir + image_name + "_in.png"))) + + print(img_gt.max(), img_gt.min()) + print(img_in.max(), img_in.min()) + io.imsave(os.path.join(dst_dir, image_name + "_lq.png"), (img_lq*255).astype(np.uint8)) + io.imsave(os.path.join(dst_dir, image_name + ".png"), (img_gt*255).astype(np.uint8)) + io.imsave(os.path.join(dst_dir, baseline+'_'+image_name + "_out.png"), (img_in*255).astype(np.uint8)) + viz_diff_img(np.abs(img_gt - img_in)*255, dst_dir, baseline+'_'+image_name + "_error_map.png") + + print("input error:", np.mean(np.abs(img_gt - img_in))) + # break diff --git a/MRI_recon/code/Frequency-Diffusion/visualization/error_map2.py b/MRI_recon/code/Frequency-Diffusion/visualization/error_map2.py new file mode 100644 index 0000000000000000000000000000000000000000..62ab1dfff89d3d8f483ec948cb069ee44d0c70c0 --- /dev/null +++ b/MRI_recon/code/Frequency-Diffusion/visualization/error_map2.py @@ -0,0 +1,53 @@ +""" +读取图像的ground truth和每个round重建结果, 并且绘制error map. 
+""" +import os +import numpy as np +from PIL import Image +from skimage import io +from matplotlib import pyplot as plt + + +def normalize_image(image): + # return (image - image.min())/(image.max() - image.min()) + image = image[40:200, 55:215] + # image = image[80:160, 95:175] + print("image shape:", image.shape) + return image/255.0 + # return (image - image.min())/(image.max() - image.min()) + + +def viz_diff_img(image, test_outputdir, image_name): + print("image range:", image.max(), image.min()) + plt.axis('off') + # plt.imshow(image, cmap='jet',vmin=0, vmax=50) + plt.imshow(image, cmap='jet',vmin=0, vmax=80) + plt.savefig(os.path.join(test_outputdir, f'{image_name}'), bbox_inches='tight',pad_inches = 0) + +# baseline = 'UNet_4X' +baseline_list = ['DCAMSR_4x', 'MCCA_4x', 'MINet_4x', 'MTrans_4x', 'swinIR_4x'] +# baseline_list = ['DCAMSR_8X', 'MCCA_8X', 'MINet_8X', 'MTrans_8X', 'swinir_8X_'] +# baseline_list = ['swinir_8X_'] +baseline_list = ['our'] +for baseline in baseline_list: + # root_dir = f"/data/qic99/recon_code/recon_2M/fastMRI_baseline/model/{baseline}/result_case/" + root_dir = '/data/qic99/recon_code/recon_2M/BRATS_freq_multi_fusion_2_2_4/model/our_fastmri_4x/result_case/' + # root_dir = '/data/qic99/recon_code/recon_2M/BRATS_freq_multi_fusion_2_2_4/model/our_fastmri_8x/result_case/' + image_name = "file1001059_11" + + # dst_dir = "./fastMRI_error_map_8X" + dst_dir = "./fastMRI_error_map_4X" + os.makedirs(dst_dir, exist_ok=True) + img_gt = normalize_image(np.array(Image.open(root_dir + image_name + ".png"))) + img_in = normalize_image(np.array(Image.open(root_dir + image_name + "_out.png"))) + img_lq = normalize_image(np.array(Image.open(root_dir + image_name + "_in.png"))) + # breakpoint() + print(img_gt.max(), img_gt.min()) + print(img_in.max(), img_in.min()) + io.imsave(os.path.join(dst_dir, image_name + "_lq.png"), (img_lq).astype(np.uint8)) + io.imsave(os.path.join(dst_dir, image_name + ".png"), (img_gt).astype(np.uint8)) + io.imsave(os.path.join(dst_dir, baseline+'_'+image_name + "_out.png"), (img_in).astype(np.uint8)) + viz_diff_img(np.abs(img_gt - img_in), dst_dir, baseline+'_'+image_name + "_error_map.png") + + print("input error:", np.mean(np.abs(img_gt - img_in))) + # break diff --git a/MRI_recon/new_code/Frequency-Diffusion-main.zip b/MRI_recon/new_code/Frequency-Diffusion-main.zip new file mode 100644 index 0000000000000000000000000000000000000000..18d4542b49306d19dae6ee9b103706c5e019b86f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:795482a3e22f789046825ebdbab416543e65e87911151db2043d4d0a5f5f1912 +size 2050396 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/.gitignore b/MRI_recon/new_code/Frequency-Diffusion-main/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d1c3ac773ddd4815dcc056d9671582ad12f2295f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/.gitignore @@ -0,0 +1,142 @@ +# Generation results +results/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +log./ +log.txt +.log + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +*.png +*.pth +# Translations +*.mo +*.pot + +# Django stuff: +*.log +log/ +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ +.DS_Store +.idea/ +apex diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/README.md b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8f9aaadef4dd0210e6f11eb09f082c241e08051e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/README.md @@ -0,0 +1,97 @@ +# FSMNet +FSMNet efficiently explores global dependencies across different modalities. Specifically, the features for each modality are extracted by the Frequency-Spatial Feature Extraction (FSFE) module, featuring a frequency branch and a spatial branch. Benefiting from the global property of the Fourier transform, the frequency branch can efficiently capture global dependency with an image-size receptive field, while the spatial branch can extract local features. To exploit complementary information from the auxiliary modality, we propose a Cross-Modal Selective fusion (CMS-fusion) module that selectively incorporate the frequency and spatial features from the auxiliary modality to enhance the corresponding branch of the target modality. To further integrate the enhanced global features from the frequency branch and the enhanced local features from the spatial branch, we develop a Frequency-Spatial fusion (FS-fusion) module, resulting in a comprehensive feature representation for the target modality. + +
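+A minimal, self-contained sketch of the idea behind the frequency branch is shown below; it is illustrative only and not the released FSMNet implementation, so the block structure, channel sizes, and module names are assumptions:
+```python
+import torch
+import torch.nn as nn
+
+
+class FrequencyBranch(nn.Module):
+    """Global mixing in the Fourier domain: every output pixel depends on all input pixels."""
+
+    def __init__(self, channels):
+        super().__init__()
+        # 1x1 convolutions act on the stacked real/imaginary parts of the spectrum.
+        self.freq_conv = nn.Sequential(
+            nn.Conv2d(2 * channels, 2 * channels, kernel_size=1),
+            nn.LeakyReLU(0.1, inplace=True),
+            nn.Conv2d(2 * channels, 2 * channels, kernel_size=1),
+        )
+
+    def forward(self, x):
+        _, _, h, w = x.shape
+        freq = torch.fft.rfft2(x, norm="ortho")             # (B, C, H, W//2 + 1), complex
+        feat = torch.cat([freq.real, freq.imag], dim=1)     # (B, 2C, H, W//2 + 1)
+        feat = self.freq_conv(feat)
+        real, imag = feat.chunk(2, dim=1)
+        return torch.fft.irfft2(torch.complex(real, imag), s=(h, w), norm="ortho")
+
+
+class ToyFSFEBlock(nn.Module):
+    """Toy frequency-spatial block: a global frequency branch plus a local 3x3 spatial branch."""
+
+    def __init__(self, channels):
+        super().__init__()
+        self.freq = FrequencyBranch(channels)
+        self.spat = nn.Sequential(
+            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.1, inplace=True),
+        )
+        self.fuse = nn.Conv2d(2 * channels, channels, kernel_size=1)
+
+    def forward(self, x):
+        return x + self.fuse(torch.cat([self.freq(x), self.spat(x)], dim=1))
+
+
+if __name__ == "__main__":
+    x = torch.randn(1, 16, 64, 64)
+    print(ToyFSFEBlock(16)(x).shape)  # torch.Size([1, 16, 64, 64])
+```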

+ +## Paper + +Accelerated Multi-Contrast MRI Reconstruction via Frequency and Spatial Mutual Learning
+[Qi Chen](https://scholar.google.com/citations?user=4Q5gs2MAAAAJ&hl=en)<sup>1</sup>, [Xiaohan Xing](https://hathawayxxh.github.io/)<sup>2,*</sup>, [Zhen Chen](https://franciszchen.github.io/)<sup>3</sup>, [Zhiwei Xiong](http://staff.ustc.edu.cn/~zwxiong/)<sup>1</sup><br>
+<sup>1</sup> University of Science and Technology of China,<br>
+<sup>2</sup> Stanford University,<br>
+<sup>3</sup> Centre for Artificial Intelligence and Robotics (CAIR), HKISI-CAS<br>
+MICCAI, 2024<br>
+[paper](http://arxiv.org/abs/2409.14113) | [code](https://github.com/qic999/FSMNet) | [huggingface](https://huggingface.co/datasets/qicq1c/MRI_Reconstruction) + +## 0. Installation + +```bash +git clone https://github.com/qic999/FSMNet.git +cd FSMNet +``` + +See [installation instructions](documents/INSTALL.md) to create an environment and obtain requirements. + +## 1. Prepare datasets +Download BraTS dataset and fastMRI dataset and save them to the `datapath` directory. +``` +cd $datapath +# download brats dataset +wget https://huggingface.co/datasets/qicq1c/MRI_Reconstruction/resolve/main/BRATS_100patients.zip +unzip BRATS_100patients.zip +# download fastmri dataset +wget https://huggingface.co/datasets/qicq1c/MRI_Reconstruction/resolve/main/singlecoil_train_selected.zip +unzip singlecoil_train_selected.zip +``` + +## 2. Training +##### BraTS dataset, AF=4 +``` +python train_brats.py --root_path /data/qic99/MRI_recon image_100patients_4X/ \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x +``` + +##### BraTS dataset, AF=8 +``` +python train_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x +``` + +##### fastMRI dataset, AF=4 +``` +python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x +``` + +##### fastMRI dataset, AF=8 +``` +python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x +``` + +## 3. 
Testing +##### BraTS dataset, AF=4 +``` +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_4X/ \ + --gpu 3 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x --phase test +``` + +##### BraTS dataset, AF=8 +``` +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \ + --gpu 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x --phase test +``` + +##### fastMRI dataset, AF=4 +``` +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 5 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --phase test +``` + +##### fastMRI dataset, AF=8 +``` +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 6 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --phase test +``` \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_DuDo_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_DuDo_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b06691ee683a347d4a20948d03598db65e9c08 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_DuDo_dataloader.py @@ -0,0 +1,295 @@ +""" +dual-domain network的dataloader, 读取两个模态的under-sampled和fully-sampled kspace data, 以及high-quality image作为监督信号。 +""" + + +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset + +from .kspace_subsample import undersample_mri, mri_fft + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. 
+ """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + + return data + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, HF_refine = 'False', split='train', MRIDOWN='4X', SNR=15, \ + transform=None, input_round = None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self.HF_refine = HF_refine + self.input_round = input_round + self._MRIDOWN = MRIDOWN + self._SNR = SNR + self.im_ids = [] + self.t2_images = [] + self.splits_path = "/data/xiaohan/BRATS_dataset/cv_splits_100patients/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + # self.test_file = self.splits_path + 'train_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + self.t2_images.append(t2_path) + + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + image_name = self.t1_images[index].split('t1')[0] + # print("image name:", image_name) + + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + # print("images:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + # 
print("t1 before standardization:", t1.max(), t1.min(), t1.mean()) + # print("loaded t1 range:", t1.max(), t1.min()) + # print("loaded t2 range:", t2.max(), t2 .min()) + + ### normalize the MRI image by divide_max + t1_max, t2_max = t1.max(), t2.max() + t1 = t1/t1_max + t2 = t2/t2_max + sample_stats = {"t1_max": t1_max, "t2_max": t2_max, "image_name": image_name} + + # sample_stats = {"t1_max": 1.0, "t2_max": 1.0} + + ### convert images to kspace and perform undersampling. + t1_kspace_in, t1_in, t1_kspace, t1_img = mri_fft(t1, _SNR = self._SNR) + t2_kspace_in, t2_in, t2_kspace, t2_img, mask = undersample_mri( + t2, _MRIDOWN = self._MRIDOWN, _SNR = self._SNR) + + + # print("loaded t2 range:", t2.max(), t2.min()) + # print("t2_under_img range:", t2_under_img.max(), t2_under_img.min()) + # print("t2_kspace real_part range:", t2_kspace.real.max(), t2_kspace.real.min()) + # print("t2_kspace imaginary_part range:", t2_kspace.imag.max(), t2_kspace.imag.min()) + # print("t2_kspace_in real_part range:", t2_kspace_in.real.max(), t2_kspace_in.real.min()) + # print("t2_kspace_in imaginary_part range:", t2_kspace_in.imag.max(), t2_kspace_in.imag.min()) + + if self.HF_refine == "False": + sample = {'t1': t1_img, 't1_in': t1_in, 't1_kspace': t1_kspace, 't1_kspace_in': t1_kspace_in, \ + 't2': t2_img, 't2_in': t2_in, 't2_kspace': t2_kspace, 't2_kspace_in': t2_kspace_in, \ + 't2_mask': mask} + + elif self.HF_refine == "True": + ### 读取上一步重建的kspace data. + t1_krecon_path = self._base_dir + self.t1_images[index].replace( + 't1.png', 't1_' + str(self._SNR) + 'dB_recon_kspace_' + self.input_round + '_DudoLoss.npy') + t2_krecon_path = self._base_dir + self.t1_images[index].replace('t1.png', 't2_' + self._MRIDOWN + \ + '_' + str(self._SNR) + 'dB_recon_kspace_' + self.input_round + '_DudoLoss.npy') + + t1_krecon = np.load(t1_krecon_path) + t2_krecon = np.load(t2_krecon_path) + # print("t1 and t2 recon kspace:", t1_krecon.shape, t2_krecon.shape) + # + sample = {'t1': t1_img, 't1_in': t1_in, 't1_kspace': t1_kspace, 't1_kspace_in': t1_kspace_in, \ + 't2': t2_img, 't2_in': t2_in, 't2_kspace': t2_kspace, 't2_kspace_in': t2_kspace_in, \ + 't2_mask': mask, 't1_krecon': t1_krecon, 't2_krecon': t2_krecon} + + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img = torch.from_numpy(img).float() + target = torch.from_numpy(target).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'ct': img, 'mri': target} + + +# class ToTensor(object): +# """Convert ndarrays in sample to Tensors.""" + +# def __call__(self, sample): +# # swap color axis because +# # numpy image: H x W x C +# # torch image: C X H X W +# img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) +# img = sample['image'][:, :, None].transpose((2, 0, 1)) +# target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) +# target = sample['target'][:, :, None].transpose((2, 0, 1)) +# # print("img_in before_numpy range:", img_in.max(), img_in.min()) +# img_in = torch.from_numpy(img_in).float() +# img = torch.from_numpy(img).float() +# target_in = torch.from_numpy(target_in).float() +# target = 
torch.from_numpy(target).float() +# # print("img_in range:", img_in.max(), img_in.min()) + +# return {'ct_in': img_in, +# 'ct': img, +# 'mri_in': target_in, +# 'mri': target} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..28d52593311a0bbe1c679fc0687cbe949e85dc7c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_dataloader.py @@ -0,0 +1,174 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import os +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', SNR=15, transform=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.splits_path = "/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/cv_splits/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + else: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + if MRIDOWN == "False": + t2_under_path = image_path.replace('t1', 't2_' + str(SNR) + 'dB') + else: + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + + # print("image paths:", image_path, t1_under_path, t2_path, t2_under_path) + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + # print("t1 images:", self.t1_images) + # print("t2 images:", self.t2_images) + # print("t1_undermri_images:", self.t1_undermri_images) + # print("t2_undermri_images:", self.t2_undermri_images) + + self.transform = transform + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + ### 两种settings. + ### 1. T1 fully-sampled 不加noise, T2 down-sampled, 做MRI acceleration. + ### 2. T1 fully-sampled 但是加noise, T2 down-sampled同时也加noise, 同时做MRI acceleration and enhancement. + ### T1, T2两个模态的输入都是low-quality images. 
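+        ### In English: setting 1 uses fully-sampled, noise-free T1 with down-sampled T2 (pure MRI acceleration);
+        ### setting 2 adds noise to the fully-sampled T1 and to the down-sampled T2 (joint MRI acceleration and enhancement).
+        ### In both settings, the T1 and T2 inputs are low-quality images.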
+ sample = {'image_in': np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0, + 'image': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + 'target_in': np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0, + 'target': np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0} + + + # ### 2023/05/23, Xiaohan, 把T1模态的输入改成high-quality图像(和ground truth一致,看能否为T2提供更好的guidance)。 + # sample = {'image_in': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + # 'image': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + # 'target_in': np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0, + # 'target': np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0} + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_dataloader_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_dataloader_new.py new file mode 100644 index 0000000000000000000000000000000000000000..beafa08c756ac25592ae2ba4fd0f673278d274c5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_dataloader_new.py @@ -0,0 +1,392 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset +from torchvision import transforms +from .albu_transform import get_albu_transforms + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', \ + SNR=15, transform=None, input_normalize=None, use_kspace=False): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.t1_krecon_images, self.t2_krecon_images = [], [] + self.kspace_refine = "False" # ADD + self.use_kspace = use_kspace + + self.albu_transforms = get_albu_transforms(split, (240, 240), + shift_limit=0.1, + scale_limit=(-0.1, 0.1), + rotate_limit=5, + distort_limit=0.15, # + elastic_alpha=1, elastic_sigma=2 + ) + + + + + name = base_dir.rstrip("/ ").split('/')[-1] + print("base_dir=", base_dir, ", folder name =", name) + self.splits_path = base_dir.replace(name, 'cv_splits_100patients/') + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + + if SNR == 0: + t1_under_path = image_path + + if self.kspace_refine == "False": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + elif self.kspace_refine == "True": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_krecon') + + if self.kspace_refine == "False": + t1_krecon_path = image_path + t2_krecon_path = image_path + + # if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + 
t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + + else: + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + t1_krecon_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_krecon_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + self.t1_krecon_images.append(t1_krecon_path) + self.t2_krecon_images.append(t2_krecon_path) + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t1_krecon = np.array(Image.open(self._base_dir + self.t1_krecon_images[index]))/255.0 + + t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + t2_krecon = np.array(Image.open(self._base_dir + self.t2_krecon_images[index]))/255.0 + + if self.input_normalize == "mean_std": + t1_in, t1_mean, t1_std = normalize_instance(t1_in, eps=1e-11) + t1 = normalize(t1, t1_mean, t1_std, eps=1e-11) + t2_in, t2_mean, t2_std = normalize_instance(t2_in, eps=1e-11) + t2 = normalize(t2, t2_mean, t2_std, eps=1e-11) + + t1_krecon = normalize(t1_krecon, t1_mean, t1_std, eps=1e-11) + t2_krecon = normalize(t2_krecon, t2_mean, t2_std, eps=1e-11) + + ### clamp input to ensure training stability. 
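+            ### Standardized images are clipped to [-10, 10]; the k-space reconstructions below are clipped to the tighter range [-6, 6].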
+ + v = 10 + t1_in = np.clip(t1_in, -v, v) + t1 = np.clip(t1, -v, v) + t2_in = np.clip(t2_in, -v, v) + t2 = np.clip(t2, -v, v) + + t1_krecon = np.clip(t1_krecon, -6, 6) + t2_krecon = np.clip(t2_krecon, -6, 6) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + t1_in = (t1_in - t1_in.min())/(t1_in.max() - t1_in.min()) + t1 = (t1 - t1.min())/(t1.max() - t1.min()) + t2_in = (t2_in - t2_in.min())/(t2_in.max() - t2_in.min()) + t2 = (t2 - t2.min())/(t2.max() - t2.min()) + sample_stats = 0 + + + elif self.input_normalize == "divide": + sample_stats = 0 + + + sample = self.albu_transforms(image=t1_in, image2=t1, + image3=t2_in, image4=t2, + image5=t1_krecon, image6=t2_krecon) + + sample = {'image_in': sample['image'].astype(float), + 'image': sample['image2'].astype(float), + 'image_krecon': sample['image5'].astype(float), + 'target_in': sample['image3'].astype(float), + 'target': sample['image4'].astype(float), + 'target_krecon': sample['image6'].astype(float)} + + + + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + + +def add_gaussian_noise(img, mean=0, std=1): + noise = std * torch.randn_like(img) + mean + noisy_img = img + noise + return torch.clamp(noisy_img, 0, 1) + + + +class AddNoise(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + add_gauss_noise = transforms.GaussianBlur(kernel_size=5) + add_poiss_noise = transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)) + + add_noise = transforms.RandomApply([add_gauss_noise, add_poiss_noise], p=0.5) + + img_in = add_noise(img_in) + target_in = add_noise(target_in) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + + return sample + + + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_krecon = sample['image_krecon'] + target_krecon = sample['target_krecon'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + img_krecon = np.pad(img_krecon, pad_size, mode='reflect') + target_krecon = np.pad(target_krecon, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + img_krecon = img_krecon[ww:ww+crop_size, hh:hh+crop_size] + target_krecon = target_krecon[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'image_krecon': img_krecon, \ + 'target_in': target_in, 'target': target, 'target_krecon': target_krecon} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, 
(new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + +class RandomFlip(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + img_krecon = sample['image_krecon'] + target_krecon = sample['target_krecon'] + + # horizontal flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 1) + img = cv2.flip(img, 1) + target_in = cv2.flip(target_in, 1) + target = cv2.flip(target, 1) + + # vertical flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 0) + img = cv2.flip(img, 0) + target_in = cv2.flip(target_in, 0) + target = cv2.flip(target, 0) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target, 'image_krecon': img_krecon, 'target_krecon': target_krecon} + return sample + + + + +class RandomRotate(object): + def __call__(self, sample, center=None, scale=1.0): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + degrees = [0, 90, 180, 270] + angle = random.choice(degrees) + + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + + img_in = cv2.warpAffine(img_in, matrix, (w, h)) + img = cv2.warpAffine(img, matrix, (w, h)) + target_in = cv2.warpAffine(target_in, matrix, (w, h)) + target = cv2.warpAffine(target, matrix, (w, h)) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + + image_krecon = sample['image_krecon'][:, :, None].transpose((2, 0, 1)) + target_krecon = sample['target_krecon'][:, :, None].transpose((2, 0, 1)) + + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + image_krecon = torch.from_numpy(image_krecon).float() + target_krecon = torch.from_numpy(target_krecon).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'image_in': img_in, + 'image': img, + 'target_in': target_in, + 'target': target, + 'image_krecon': image_krecon, + 'target_krecon': target_krecon} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_kspace_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_kspace_dataloader.py new file mode 100644 index 
0000000000000000000000000000000000000000..871a153b20eac89e45ec0025e2aa31476360fde0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/BRATS_kspace_dataloader.py @@ -0,0 +1,298 @@ +""" +Load the low-quality and high-quality images from the BRATS dataset and transform to kspace. +""" + + +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset + +from .kspace_subsample import undersample_mri, mri_fft + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + Returns: + The IFFT of the input. 
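    Example (illustrative sketch added for clarity; the tensor below is a stand-in, not
    data from this project):

        k = torch.randn(240, 240, 2)   # hypothetical centered k-space, real/imag stacked in dim -1
        img = ifft2c(k)                # same shape; inverse FFT applied over the two spatial dims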
+ """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + + return data + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', SNR=15, transform=None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.splits_path = "/data/xiaohan/BRATS_dataset/cv_splits_100patients/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + # self.test_file = self.splits_path + 'train_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + else: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + # print("t1 images:", self.t1_images) + # print("t2 images:", self.t2_images) + # print("t1_undermri_images:", self.t1_undermri_images) + # print("t2_undermri_images:", self.t2_undermri_images) + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + # t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + # t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + # print("images:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + # print("t1 before standardization:", t1.max(), t1.min(), t1.mean()) + # print("t1 range:", t1.max(), t1.min()) + # print("t2 range:", t2.max(), t2 .min()) + + if self.input_normalize == "mean_std": + ### 对input image和target image都做(x-mean)/std的归一化操作 + t1, t1_mean, t1_std = normalize_instance(t1, eps=1e-11) + t2, t2_mean, t2_std = normalize_instance(t2, eps=1e-11) + + ### clamp input to ensure training stability. 
+ t1 = np.clip(t1, -6, 6) + t2 = np.clip(t2, -6, 6) + # print("t1 after standardization:", t1.max(), t1.min(), t1.mean()) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + # t1 = (t1 - t1.min())/(t1.max() - t1.min()) + # t2 = (t2 - t2.min())/(t2.max() - t2.min()) + t1 = t1/t1.max() + t2 = t2/t2.max() + sample_stats = 0 + + elif self.input_normalize == "divide": + sample_stats = 0 + + + ### convert images to kspace and perform undersampling. + # t1_kspace, t1_masked_kspace, t1_img, t1_under_img = undersample_mri(t1, _MRIDOWN = None) + t1_kspace, t1_img = mri_fft(t1) + t2_kspace, t2_masked_kspace, t2_img, t2_under_img, mask = undersample_mri(t2, _MRIDOWN = self._MRIDOWN) + + + sample = {'t1': t1_img, 't2': t2_img, 'under_t2': t2_under_img, "t2_mask": mask, \ + 't1_kspace': t1_kspace, 't2_kspace': t2_kspace, 't2_masked_kspace': t2_masked_kspace} + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img = torch.from_numpy(img).float() + target = torch.from_numpy(target).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'ct': img, 'mri': target} + + +# class ToTensor(object): +# """Convert ndarrays in sample to Tensors.""" + +# def __call__(self, sample): +# # swap color axis because +# # numpy image: H x W x C +# # torch image: C X H X W +# img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) +# img = sample['image'][:, :, None].transpose((2, 0, 1)) +# target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) +# target = sample['target'][:, :, None].transpose((2, 0, 1)) +# # print("img_in before_numpy range:", img_in.max(), img_in.min()) +# img_in = torch.from_numpy(img_in).float() +# img = torch.from_numpy(img).float() +# target_in = torch.from_numpy(target_in).float() +# target = torch.from_numpy(target).float() +# # print("img_in range:", img_in.max(), img_in.min()) + +# return {'ct_in': img_in, +# 'ct': img, +# 'mri_in': target_in, +# 'mri': target} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/__init__.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c3f98c8fd581dcd081a7ada0bc91184eec8aea8 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/albu_transform.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/albu_transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5863128cc61413f2d088368f5300befddb712d3d Binary files /dev/null and 
b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/albu_transform.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/fastmri.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/fastmri.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f6f4c778e6e8dad93f5b3626f9392b565865645 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/fastmri.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/math.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/math.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd4495a0193392b488fad5b9339d2b677dfcc6eb Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/math.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/subsample.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/subsample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3be9b6de1bde21fcdabc5ac347d41f6f6036a8d Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/subsample.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/transforms.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4233d1e09ea59e3bbbf226e9bea6ed33510dc2f Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/__pycache__/transforms.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/albu_transform.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/albu_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..3a044e9d1abbe17f34a2a644ffe96c43e938dc43 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/albu_transform.py @@ -0,0 +1,84 @@ +# -*- encoding: utf-8 -*- +#Time :2022/02/24 18:14:15 +#Author :Hao Chen +#FileName :trans_lib.py +#Version :2.0 + +import cv2 +import torch +import numpy as np +import albumentations as A + + +def get_albu_transforms(type="train", + img_size = (192, 192), + shift_limit=0.2, + scale_limit=(-0.2, 0.2), + rotate_limit=5, + distort_limit=0.3, + elastic_alpha=2, elastic_sigma=5): + if type == 'train': + compose = [ + # A.VerticalFlip(p=0.5), + # A.HorizontalFlip(p=0.5), + + A.ShiftScaleRotate(shift_limit=shift_limit, scale_limit=scale_limit, + rotate_limit=rotate_limit, p=0.5), + + A.OneOf([ + A.GridDistortion(num_steps=1, distort_limit=distort_limit, p=1.0), + A.ElasticTransform(alpha=elastic_alpha, sigma=elastic_sigma, p=1.0) + + ], p=0.5), + + + + A.Resize(img_size[0], img_size[1])] + else: + compose = [A.Resize(img_size[0], img_size[1])] + + return A.Compose(compose, p=1.0, additional_targets={'image2': 'image', + 'image3': 'image', + 'image4': 'image', + 'image5': 'image', + 'image6': 'image', + "mask2": "mask"}) + + + + +# Beta function +def gamma_concern(img, gamma): + mean = torch.mean(img) + + img = (img - mean) * gamma + img = img + mean + img = torch.clip(img, 0, 1) + + 
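    # Net effect: deviations from the per-image mean are scaled by gamma and the result is
    # clipped back to [0, 1]; gamma > 1 strengthens contrast, gamma < 1 flattens it.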
return img + +def gamma_power(img, gamma, direction=0): + if direction == 1: + img = 1 - img + img = torch.pow(img, gamma) + + img = img / torch.max(img) + if direction == 1: + img = 1 - img + + return img + +def gamma_exp(img, gamma, direction=0): + if direction == 1: + img = 1 - img + + img = torch.exp(img * gamma) + img = img / torch.max(img) + + if direction == 1: + img = 1 - img + return img + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_12_kspace_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_12_kspace_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..b859a0423a0e9e5805c9ff4ef64a0c51737f38a3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_12_kspace_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62b353be27d6dae1e1e0b3f68615d3d77e1e16098a17af15e660c8ed91c34a83 +size 1152128 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_4X_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_4X_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..bdf32304f95640286541ceb1068582dc69b0d60a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_4X_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76341ba680a0bc9c80389e01f8511e5bd99ab361eeb48d83516904b84cccc518 +size 460928 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_8X_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_8X_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..c389e708adeb3307db90ff071599256b8f59dab5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_8X_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c5160add079e8f4dc2496e5ef87c110015026d9f6116329da2238a73d8bc104 +size 230528 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_8_kspace_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_8_kspace_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..95cd5d8e9f54a955178b53b4313ef3a229c53d3e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_8_kspace_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbceb8d0c07b0936c34fc750e86343cee96e34eb429d72d3dede93488b3b737f +size 1152128 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_data_gen.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_data_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..b14c5361c534a67edc6a9fef311fce4f7f45fda4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/brats_data_gen.py @@ -0,0 +1,302 @@ +""" +Xiaohan Xing, 2023/04/08 +对BRATS 2020数据集进行Pre-processing, 得到各个模态的under-sampled input image和2d groung-truth. 
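(In English: pre-process the BRATS 2020 dataset to obtain an under-sampled input image and
a 2D ground truth for each modality.)

Usage sketch (paths below are placeholders, not paths from this repository):
    python brats_data_gen.py --rootdir /path/to/BRATS_2020/ --savedir /path/to/output/ \
        --source t1 --target t2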
+""" +import os +import argparse +import numpy as np +import nibabel as nib +from scipy import ndimage as nd +from scipy import ndimage +from skimage import filters +from skimage import io +import torch +import torch.fft +from matplotlib import pyplot as plt + +MRIDOWN=2 +SNR = 35 + + +class MaskFunc_Cartesian: + """ + MaskFunc creates a sub-sampling mask of a given shape. + The mask selects a subset of columns from the input k-space data. If the k-space data has N + columns, the mask picks out: + a) N_low_freqs = (N * center_fraction) columns in the center corresponding to + low-frequencies + b) The other columns are selected uniformly at random with a probability equal to: + prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). + This ensures that the expected number of columns selected is equal to (N / acceleration) + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is + called. + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there + is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% + probability that 8-fold acceleration with 4% center fraction is selected. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly + each time. + accelerations (List[int]): Amount of under-sampling. This should have the same length + as center_fractions. If multiple values are provided, then one of these is chosen + uniformly each time. An acceleration of 4 retains 25% of the columns, but they may + not be spaced evenly. + """ + if len(center_fractions) != len(accelerations): + raise ValueError('Number of center fractions should match number of accelerations') + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random.RandomState() + + def __call__(self, shape, seed=None): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The shape should have + at least 3 dimensions. Samples are drawn along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting the seed + ensures the same mask is generated each time for the same shape. + Returns: + torch.Tensor: A mask of the specified shape. 
+ """ + if len(shape) < 3: + raise ValueError('Shape should have 3 or more dimensions') + + self.rng.seed(seed) + num_cols = shape[-2] + + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + # Create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs + 1e-10) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad:pad + num_low_freqs] = True + + # Reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + mask = mask.repeat(shape[0], 1, 1) + + return mask + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2)) + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + spectrum = spectrum * mask[None, :, :, None] + return spectrum + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2)) + + return image + + +def simulate_undersample_mri(raw_mri): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + ff = MaskFunc_Cartesian([0.2], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4 + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + kspace = mri_fourier_transform_2d(mri, mask) + kspace = add_gaussian_noise(kspace) + mri_recon = mri_inver_fourier_transform_2d(kspace) + kdata = torch.sqrt(kspace.real ** 2 + kspace.imag ** 2 + 1e-10) + kdata = kdata.data.numpy()[0, :, :, 0] + + under_img = torch.sqrt(mri_recon.real ** 2 + mri_recon.imag ** 2) + under_img = under_img.data.numpy()[0, :, :, 0] + + return under_img, kspace + + +def add_gaussian_noise(img, snr=15): + ### 根据SNR确定noise的放大比例 + num_pixels = img.shape[0]*img.shape[1]*img.shape[2]*img.shape[3] + psr = torch.sum(torch.abs(img.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + + noise_r = torch.randn_like(img.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(img.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(img.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noise_img = img + noise + # print("original image:", img) + # print("gaussian noise:", noise) + + return noise_img + + +def complexsing_addnoise(img, snr): + ### add noise to the real part of the image. + img_numpy = img.cpu().numpy() + # print("kspace data:", img) + s_r = np.real(img_numpy) + num_pixels = s_r.shape[0]*s_r.shape[1]*s_r.shape[2]*s_r.shape[3] + psr = np.sum(np.abs(s_r)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + # print("PSR:", psr, "PNR:", pnr) + noise_r = np.random.randn(num_pixels)*np.sqrt(pnr) + + ### add noise to the iamginary part of the image. 
+ s_im = np.imag(img_numpy) + psim = np.sum(np.abs(s_im)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = np.random.randn(num_pixels)*np.sqrt(pnim) + + noise = torch.Tensor(noise_r) + 1j*torch.Tensor(noise_im) + sn = img + noise + # print("noisy data:", sn) + # sn = torch.Tensor(sn) + + return sn + + + +def _parse(rootdir): + filetree = {} + + for sample_file in os.listdir(rootdir): + sample_dir = rootdir + sample_file + subject = sample_file + + for filename in os.listdir(sample_dir): + modality = filename.split('.').pop(0).split('_')[-1] + + if subject not in filetree: + filetree[subject] = {} + filetree[subject][modality] = filename + + return filetree + + + +def clean(rootdir, savedir, source_modality, target_modality): + filetree = _parse(rootdir) + print("filetree:", filetree) + + if not os.path.exists(savedir+'/img_norm'): + os.makedirs(savedir+'/img_norm') + + for subject, modalities in filetree.items(): + print(f'{subject}:') + + if source_modality not in modalities or target_modality not in modalities: + print('-> incomplete') + continue + + source_path = os.path.join(rootdir, subject, modalities[source_modality]) + target_path = os.path.join(rootdir, subject, modalities[target_modality]) + + source_image = nib.load(source_path) + target_image = nib.load(target_path) + + source_volume = source_image.get_fdata() + target_volume = target_image.get_fdata() + source_binary_volume = np.zeros_like(source_volume) + target_binary_volume = np.zeros_like(target_volume) + + print("source volume:", source_volume.shape) + print("target volume:", target_volume.shape) + + for i in range(source_binary_volume.shape[-1]): + source_slice = source_volume[:, :, i] + target_slice = target_volume[:, :, i] + + if source_slice.min() == source_slice.max(): + print("invalide source slice") + source_binary_volume[:, :, i] = np.zeros_like(source_slice) + else: + source_binary_volume[:, :, i] = ndimage.morphology.binary_fill_holes( + source_slice > filters.threshold_li(source_slice)) + + if target_slice.min() == target_slice.max(): + print("invalide target slice") + target_binary_volume[:, :, i] = np.zeros_like(target_slice) + else: + target_binary_volume[:, :, i] = ndimage.morphology.binary_fill_holes( + target_slice > filters.threshold_li(target_slice)) + + source_volume = np.where(source_binary_volume, source_volume, np.ones_like( + source_volume) * source_volume.min()) + target_volume = np.where(target_binary_volume, target_volume, np.ones_like( + target_volume) * target_volume.min()) + + ## resize + if source_image.header.get_zooms()[0] < 0.6: + scale = np.asarray([240, 240, source_volume.shape[-1]]) / np.asarray(source_volume.shape) + source_volume = nd.zoom(source_volume, zoom=scale, order=3, prefilter=False) + target_volume = nd.zoom(target_volume, zoom=scale, order=0, prefilter=False) + + # save volume into images + source_volume = (source_volume-source_volume.min())/(source_volume.max()-source_volume.min()) + target_volume = (target_volume-target_volume.min())/(target_volume.max()-target_volume.min()) + + for i in range(source_binary_volume.shape[-1]): + source_binary_slice = source_binary_volume[:, :, i] + target_binary_slice = target_binary_volume[:, :, i] + if source_binary_slice.max() > 0 and target_binary_slice.max() > 0: + dd = target_volume.shape[0] // 2 + target_slice = target_volume[dd - 120:dd + 120, dd - 120:dd + 120, i] + source_slice = source_volume[dd - 120:dd + 120, dd - 120:dd + 120, i] + print("source slice range:", source_slice.shape) + print("target slice range:", 
target_slice.max(), target_slice.min()) + # undersample MRI + source_under_img, source_kspace = simulate_undersample_mri(source_slice) + target_under_img, target_kspace = simulate_undersample_mri(target_slice) + + # # io.imsave(savedir+'/img_norm/'+subject+'_'+str(i)+'_'+source_modality+'.png', (source_slice * 255.0).astype(np.uint8)) + io.imsave(savedir + '/img_temp/' + subject + '_' + str(i) + '_' + source_modality + '_' + str(MRIDOWN) + 'X_' + str(SNR) + 'dB_undermri.png', + (source_under_img * 255.0).astype(np.uint8)) + + # io.imsave(savedir + '/img_temp/' + subject + '_' + str(i) + '_' + source_modality + '_' + str(MRIDOWN) + 'X_undermri.png', + # (source_under_img * 255.0).astype(np.uint8)) + # # io.imsave(savedir+'/img_norm/'+subject+'_'+str(i)+'_'+target_modality+'.png', (target_slice * 255.0).astype(np.uint8)) + # io.imsave(savedir + '/img_norm/' + subject + '_' + str(i) + '_' + target_modality + '_' + str(MRIDOWN) + 'X_undermri.png', + # (target_under_img * 255.0).astype(np.uint8)) + + # np.savez_compressed(rootdir + '/img_norm/' + subject + '_' + str(i) + '_' + target_modality + '_raw_'+str(MRIDOWN)+'X'+str(CTNVIEW)+'P', + # kspace=kspace, under_t1=under_img, + # t1=source_slice, ct=target_slice) + + +def main(args): + clean(args.rootdir,args.savedir, args.source, args.target) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--rootdir', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020/') + parser.add_argument('--savedir', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/') + parser.add_argument('--source', default='t1') + parser.add_argument('--target', default='t2') + + main(parser.parse_args()) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/kspace_12_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/kspace_12_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..689cdeb8373100cb53c73db6a1f78176737667ec --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/kspace_12_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4be9178cd1a6a9456a19be09a98d0b3ecb862907b1835388ae8097ef2df01321 +size 2048128 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/kspace_4_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/kspace_4_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..9ac77fa44d98099a5c07948465d1c0096de38828 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/kspace_4_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f68ba364235a51534884b434ac3a1c16d0cf263b9e4c08c5b3757214a6f78216 +size 2048128 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/kspace_8_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/kspace_8_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..f0217bbe6f62b18296c488807e2d8a90ac7f0118 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/kspace_8_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f7397d527311ac6ba09ee2621d2f964e276a16b4bf0aaded163653abef882bb +size 2048128 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/m4raw_4_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/m4raw_4_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..bc50082cfed244305fcbc0b19fa0c7dafec851ef --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/m4raw_4_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94d140842f0e44fbf73655a49e48308b694f232864a483ad45985bf94a82977c +size 1152128 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/m4raw_8_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/m4raw_8_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..8b5d65e7b18b698e5a071f63d78da2fda193d6a4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/example_mask/m4raw_8_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5f65a474271059263f57c0b3c342622da34a498f62739b0b997f92e8dcebf68 +size 4096128 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/fastmri.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..af7dd9ff18d6f2a70e98f5cd50938d1d5cad9fd9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/fastmri.py @@ -0,0 +1,362 @@ +import csv +import os +import random +import xml.etree.ElementTree as etree +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import pathlib + +import h5py +import numpy as np +import torch +import yaml +from torch.utils.data import Dataset +from .transforms import build_transforms +from matplotlib import pyplot as plt + +from .albu_transform import get_albu_transforms + +def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")): + """ + Data directory fetcher. + + This is a brute-force simple way to configure data directories for a + project. Simply overwrite the variables for `knee_path` and `brain_path` + and this function will retrieve the requested subsplit of the data for use. + + Args: + key (str): key to retrieve path from data_config_file. + data_config_file (pathlib.Path, + default=pathlib.Path("fastmri_dirs.yaml")): Default path config + file. + + Returns: + pathlib.Path: The path to the specified directory. + """ + if not data_config_file.is_file(): + default_config = dict( + knee_path="/home/jc3/Data/", + brain_path="/home/jc3/Data/", + ) + with open(data_config_file, "w") as f: + yaml.dump(default_config, f) + + raise ValueError(f"Please populate {data_config_file} with directory paths.") + + with open(data_config_file, "r") as f: + data_dir = yaml.safe_load(f)[key] + + data_dir = pathlib.Path(data_dir) + + if not data_dir.exists(): + raise ValueError(f"Path {data_dir} from {data_config_file} does not exist.") + + return data_dir + + +def et_query( + root: etree.Element, + qlist: Sequence[str], + namespace: str = "http://www.ismrm.org/ISMRMRD", +) -> str: + """ + ElementTree query function. + This can be used to query an xml document via ElementTree. It uses qlist + for nested queries. + Args: + root: Root of the xml to search through. + qlist: A list of strings for nested searches, e.g. ["Encoding", + "matrixSize"] + namespace: Optional; xml namespace to prepend query. + Returns: + The retrieved data as a string. + """ + s = "." 
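    # The loop below builds a namespace-qualified XPath string such as
    # ".//ismrmrd_namespace:encoding//ismrmrd_namespace:encodedSpace//ismrmrd_namespace:matrixSize"
    # and resolves it against the ISMRMRD header via ElementTree's namespace map.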
+ prefix = "ismrmrd_namespace" + + ns = {prefix: namespace} + + for el in qlist: + s = s + f"//{prefix}:{el}" + + value = root.find(s, ns) + if value is None: + raise RuntimeError("Element not found") + + return str(value.text) + + + + + + +import albumentations as A + + + +class SliceDataset(Dataset): + def __init__( + self, + root, + transform, + challenge, + sample_rate=1, + mode='train' + ): + self.mode = mode + self.albu_transforms = get_albu_transforms(self.mode, (320, 320), + shift_limit=0.03, + scale_limit=(-0.03, 0.03), + rotate_limit=2, + distort_limit=0.03, # + elastic_alpha=1, + elastic_sigma=1 + ) + + + # challenge + if challenge not in ("singlecoil", "multicoil"): + raise ValueError('challenge should be either "singlecoil" or "multicoil"') + self.recons_key = ( + "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss" + ) + # transform + self.transform = transform + + self.examples = [] + + self.cur_path = root + + + if not os.path.exists(self.cur_path): + self.cur_path = self.cur_path + "_selected" + + print("cur_path:", self.cur_path) + + + self.csv_file = "knee_data_split/singlecoil_" + self.mode + "_split_less.csv" + + with open(self.csv_file, 'r') as f: + reader = csv.reader(f) + + id = 0 + + for row in reader: + pd_metadata, pd_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[0] + '.h5')) + + pdfs_metadata, pdfs_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[1] + '.h5')) + + for slice_id in range(min(pd_num_slices, pdfs_num_slices)): + self.examples.append( + (os.path.join(self.cur_path, row[0] + '.h5'), os.path.join(self.cur_path, row[1] + '.h5') + , slice_id, pd_metadata, pdfs_metadata, id)) + id += 1 + + if sample_rate < 1: + random.shuffle(self.examples) + num_examples = round(len(self.examples) * sample_rate) + + self.examples = self.examples[0:num_examples] + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + + # read pd + pd_fname, pdfs_fname, slice, pd_metadata, pdfs_metadata, id = self.examples[i] + + with h5py.File(pd_fname, "r") as hf: + pd_kspace = hf["kspace"][slice] + + pd_mask = np.asarray(hf["mask"]) if "mask" in hf else None + + pd_target = hf[self.recons_key][slice] if self.recons_key in hf else None + + attrs = dict(hf.attrs) + + attrs.update(pd_metadata) + + if self.transform is None: + pd_sample = (pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + else: + pd_sample = self.transform(pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + + with h5py.File(pdfs_fname, "r") as hf: + pdfs_kspace = hf["kspace"][slice] + pdfs_mask = np.asarray(hf["mask"]) if "mask" in hf else None + + pdfs_target = hf[self.recons_key][slice] if self.recons_key in hf else None + + attrs = dict(hf.attrs) + + attrs.update(pdfs_metadata) + + if self.transform is None: + pdfs_sample = (pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + else: + pdfs_sample = self.transform(pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + + # 0: input, 1: target, 2: mean, 3: std + sample = self.albu_transforms(image=pdfs_sample[1].numpy(), + image2=pd_sample[1].numpy(), + image3=pdfs_sample[0].numpy(), + image4=pd_sample[0].numpy()) + + pdfs_sample = list(pdfs_sample) + pd_sample = list(pd_sample) + pdfs_sample[1] = sample['image'] + pd_sample[1] = sample['image2'] + pdfs_sample[0] = sample['image3'] + pd_sample[0] = sample['image4'] + + # dataset pdf mean and std tensor(3.1980e-05) tensor(1.3093e-05) + # print("dataset pdf mean and std", pdfs_sample[2], 
pdfs_sample[3]) + # print(pdfs_sample[1].shape, pdfs_sample[1].min(), pdfs_sample[1].max()) + + return (pd_sample, pdfs_sample, id) + + + + + def _retrieve_metadata(self, fname): + with h5py.File(fname, "r") as hf: + et_root = etree.fromstring(hf["ismrmrd_header"][()]) + + enc = ["encoding", "encodedSpace", "matrixSize"] + enc_size = ( + int(et_query(et_root, enc + ["x"])), + int(et_query(et_root, enc + ["y"])), + int(et_query(et_root, enc + ["z"])), + ) + rec = ["encoding", "reconSpace", "matrixSize"] + recon_size = ( + int(et_query(et_root, rec + ["x"])), + int(et_query(et_root, rec + ["y"])), + int(et_query(et_root, rec + ["z"])), + ) + + lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"] + enc_limits_center = int(et_query(et_root, lims + ["center"])) + enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1 + + padding_left = enc_size[1] // 2 - enc_limits_center + padding_right = padding_left + enc_limits_max + + num_slices = hf["kspace"].shape[0] + + metadata = { + "padding_left": padding_left, + "padding_right": padding_right, + "encoding_size": enc_size, + "recon_size": recon_size, + } + + return metadata, num_slices + + +def build_dataset(args, mode='train', sample_rate=1, use_kspace=False): + assert mode in ['train', 'val', 'test'], 'unknown mode' + transforms = build_transforms(args, mode, use_kspace) + + return SliceDataset(os.path.join(args.root_path, 'singlecoil_' + mode), transforms, 'singlecoil', sample_rate=sample_rate, mode=mode) + + +if __name__ == "__main__": + ## make logger file + from torch.utils.data import DataLoader + from option import args + import time + from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_ksu_kernel, apply_tofre, \ + apply_to_spatial + + batch_size = 1 + db_train = build_dataset(args, mode='train') + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + + for i_batch, sampled_batch in enumerate(trainloader): + time2 = time.time() + # print("time for data loading:", time2 - time1) + + pd, pdfs, _ = sampled_batch + target = pdfs[1] + + mean = pdfs[2] + std = pdfs[3] + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + target = target.unsqueeze(1) + + b = pd_img.size(0) + + pd_img = pd_img # [4, 1, 320, 320] + pdfs_img = pdfs_img # [4, 1, 320, 320] + target = target # [4, 1, 320, 320] + + # ----------- Degradation ------------- + num_timesteps = 1 + image_size = 320 + + # Output a list of k-space kernels + kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0], + ) # args.ACCELERATIONS = [4] or [8] + kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + + + + t = torch.randint(0, num_timesteps, (b,)).long() + mask = kspace_masks[t] + fft, mask = apply_tofre(target.clone(), mask) + # fft = fft * mask + 0.0 + pdfs_img = apply_to_spatial(fft) + pdfs_img_mask = apply_to_spatial(mask * fft)[0] + + + + + print("mask = ", mask.shape, mask.min(), mask.max()) + print("pdfs_img_mask =", pdfs_img_mask.shape) + + import matplotlib.pyplot as plt + + # combine them together + pd_img = pd_img.squeeze(1).cpu().numpy() + pdfs_img = pdfs_img.squeeze(1).cpu().numpy() + target = target.squeeze(1).cpu().numpy() + + plt.figure() + + plt.subplot(161) + plt.imshow(pd_img[0], cmap='gray') + plt.title('PD') + plt.axis('off') + plt.subplot(162) + + plt.imshow(pdfs_img_mask[0], cmap='gray') + plt.title('PDFS_mask') + plt.axis('off') + + plt.subplot(163) + 
plt.imshow(pdfs_img[0], cmap='gray') + plt.title('PDFS') + plt.axis('off') + + plt.subplot(164) + plt.imshow(pdfs_img_mask[0] - target[0], cmap='gray') + plt.title('Diff') + plt.axis('off') + + plt.subplot(165) + plt.imshow(target[0], cmap='gray') + plt.title('Target') + plt.axis('off') + + plt.subplot(166) + plt.imshow(pdfs_img[0] - target[0], cmap='gray')#mask[0][0], cmap='gray') + plt.title('Target') + plt.axis('off') + + plt.show() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/hybrid_sparse.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/hybrid_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..42e4a7e33c2204c13a1c4509897baf19e1fb07f1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/hybrid_sparse.py @@ -0,0 +1,156 @@ +from __future__ import print_function, division +import numpy as np +from glob import glob +import random +from skimage import transform + +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', transform=None): + + super().__init__() + self._base_dir = base_dir + self.im_ids = [] + self.images = [] + self.gts = [] + + if split=='train': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir+"/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + + elif split=='test': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir + "/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + self.transform = transform + + assert (len(self.images) == len(self.gts)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.images))) + + def __len__(self): + return len(self.images) + + + def __getitem__(self, index): + img_in, img, target_in, target= self._make_img_gt_point_pair(index) + sample = {'image_in': img_in, 'image':img, 'target_in': target_in, 'target': target} + # print("image in:", img_in.shape) + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + + # the default setting (i.e., rawdata.npz) is 4X64P + dd = np.load(self.images[index].replace('.png', '_raw_4X64P.npz')) + # print("images range:", dd['fbp'].max(), dd['ct'].max(), dd['under_t1'].max(), dd['t1'].max()) + _img_in = dd['fbp'] + _img_in[_img_in>0.6]=0.6 + _img_in = _img_in/0.6 + + _img = dd['ct'] + _img =(_img/1000*0.192+0.192) + _img[_img<0.0]=0.0 + _img[_img>0.6]=0.6 + _img = _img/0.6 + + _target_in = dd['under_t1'] + _target = dd['t1'] + + return _img_in, _img, _target_in, _target + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 400, 400 + crop_size = 384 + pad_size = (400-384)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = 
img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/kspace_subsample.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/kspace_subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..49da4fa5e508df325a98767e46725e93c9be0445 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/kspace_subsample.py @@ -0,0 +1,328 @@ +""" +2023/10/16, +preprocess kspace data with the undersampling mask in the fastMRI project. 
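Example (illustrative sketch; the input array is a stand-in, but the call matches the
undersample_mri signature defined below):

    import numpy as np
    slice_img = np.random.rand(240, 240)        # hypothetical normalized MRI slice
    masked_k, under_img, full_k, full_img, mask = undersample_mri(
        slice_img, _MRIDOWN="4X", _SNR=15)
    # masked_k / full_k: [1, 240, 240] complex k-space (undersampled + noisy vs. full),
    # under_img / full_img: [1, 240, 240] magnitude images, mask: the Cartesian column mask.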
+""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + masked_spectrum = spectrum * mask[None, :, :, None] + return spectrum, masked_spectrum + + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2), norm='ortho') + + return image + + +def add_gaussian_noise(kspace, snr): + ### 根据SNR确定noise的放大比例 + num_pixels = kspace.shape[0]*kspace.shape[1]*kspace.shape[2]*kspace.shape[3] + psr = torch.sum(torch.abs(kspace.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(kspace.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(kspace.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(kspace.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noisy_kspace = kspace + noise + + return noisy_kspace + + +def mri_fft(raw_mri, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + + if _SNR > 0: + noisy_kspace = add_gaussian_noise(kspace, _SNR) + else: + noisy_kspace = kspace + + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1) + + +from dataloaders.math import complex_abs, complex_abs_numpy, complex_abs_sq + +def mri_fft_m4raw(lq_mri, hq_mri): + # breakpoint() + lq_mri = torch.tensor(lq_mri[0])[None, :, :, None].to(torch.float32) + lq_mri_spectrum = torch.fft.fftn(lq_mri, dim=(1, 2), norm='ortho') + lq_mri_spectrum = torch.fft.fftshift(lq_mri_spectrum, dim=(1, 2)) + + # Complex + lq_mri = mri_inver_fourier_transform_2d(lq_mri_spectrum[0]) + # print("lq_mri shape:", lq_mri.shape) + lq_mri = torch.cat([torch.real(lq_mri), torch.imag(lq_mri)], dim=-1) + lq_mri = complex_abs(lq_mri) + lq_mri = torch.abs(lq_mri) + # print("lq_mri after shape:", lq_mri.shape) + lq_mri = lq_mri.unsqueeze(-1) + # + lq_kspace = torch.cat([torch.real(lq_mri_spectrum), torch.imag(lq_mri_spectrum)], dim=-1) + lq_kspace = torch.abs(complex_abs(lq_kspace[0])) + lq_kspace = lq_kspace.unsqueeze(-1) + + hq_mri = torch.tensor(hq_mri[0])[None, :, :, None].to(torch.float32) + hq_mri_spectrum = torch.fft.fftn(hq_mri, dim=(1, 2), norm='ortho') + hq_mri_spectrum = torch.fft.fftshift(hq_mri_spectrum, dim=(1, 2)) + + 
hq_mri = mri_inver_fourier_transform_2d(hq_mri_spectrum[0]) + hq_mri = torch.cat([torch.real(hq_mri), torch.imag(hq_mri)], dim=-1) + + hq_mri = complex_abs(hq_mri) # Convert the complex number to the absolute value. + hq_mri = torch.abs(hq_mri) + hq_mri = hq_mri.unsqueeze(-1) + # + hq_kspace = torch.cat([torch.real(hq_mri_spectrum), torch.imag(hq_mri_spectrum)], dim=-1) + + hq_kspace = torch.abs(complex_abs(hq_kspace[0])) + hq_kspace = hq_kspace.unsqueeze(-1) + + # breakpoint() + return lq_kspace, lq_mri.permute(2, 0, 1), \ + hq_kspace, hq_mri.permute(2, 0, 1) + + +def undersample_mri(raw_mri, _MRIDOWN, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] # [1, 240] + # print("mask:", mask.shape) + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + ### under-sample the kspace data. + kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) + ### add low-field noise to the kspace data. + if _SNR > 0: + noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) + else: + noisy_kspace = masked_kspace + + ### conver the corrupted kspace data back to noisy MRI image. + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). 
This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
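        Example (illustrative sketch, not part of the original code):
            mask_func = EquispacedMaskFunc(center_fractions=[0.04], accelerations=[8])
            mask = mask_func([240, 240, 1], seed=1337)
            # mask has shape [1, 240, 1]; the central ~4% of columns are always kept and the
            # remainder are sampled on an (approximately) equispaced grid for 8x acceleration.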
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/m4_utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/m4_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0d76246b0c8fb757b6b589b7f889e0316696d0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/m4_utils.py @@ -0,0 +1,272 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. + + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + +# def fft2c(data): +# """ +# Apply centered 2 dimensional Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The FFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# data = torch.fft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + +# def ifft2c(data): +# """ +# Apply centered 2-dimensional Inverse Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. 
All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The IFFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# # data = torch.ifft(data, 2, normalized=True) +# data = torch.fft.ifft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. 
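# --- Illustrative round-trip sketch (editorial; assumes fft2c / ifft2c above
# are in scope): a centered forward FFT followed by the centered inverse FFT
# should reproduce the input up to floating-point error.
import torch

img = torch.randn(1, 4, 64, 64, 2)            # batch, coil, H, W, (re, im)
kspace = fft2c(img)
recon = ifft2c(kspace)
print(torch.allclose(img, recon, atol=1e-5))  # expected: True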
+ """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/m4raw_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/m4raw_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..bb2553d67f2e1d680fef57d89598fd02222c6cb8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/m4raw_dataloader.py @@ -0,0 +1,606 @@ + +from __future__ import print_function, division +from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union +import sys +sys.path.append('.') +from glob import glob +import os, time +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import numpy as np +import torch +from torch.utils.data import Dataset + +import h5py +from matplotlib import pyplot as plt +from dataloaders.math import ifft2c, fft2c, complex_abs +from dataloaders.kspace_subsample import create_mask_for_mask_type + +import argparse +from torch.utils.data import DataLoader +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from dataloaders.kspace_subsample import undersample_mri, mri_fft, mri_fft_m4raw +from tqdm import tqdm +import h5py + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. 
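# --- Illustrative check (editorial; assumes the roll / fftshift / ifftshift
# helpers above are in scope): the PyTorch shifts should agree with NumPy's
# fftshift / ifftshift for both even and odd sizes.
import numpy as np
import torch

for n in (8, 9):
    x = torch.arange(n * n, dtype=torch.float32).reshape(n, n)
    print(np.array_equal(fftshift(x, dim=(0, 1)).numpy(), np.fft.fftshift(x.numpy())),
          np.array_equal(ifftshift(x, dim=(0, 1)).numpy(), np.fft.ifftshift(x.numpy())))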
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +def normal(x): + y = np.zeros_like(x) + for i in range(y.shape[0]): + x_min = x[i].min() + x_max = x[i].max() + y[i] = (x[i] - x_min)/(x_max-x_min) + return y + + + +def undersample_mri(kspace, _MRIDOWN): + # print("kspace shape:", kspace.shape) ## [18, 4, 256, 256, 2] + + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [256, 256, 1] + mask = ff(shape, seed=1337) ## [1, 256, 1] + mask = mask[:, :, 0] # [1, 256] + + masked_kspace = kspace * mask[None, None, :, :, None] + + return masked_kspace, mask.unsqueeze(-1) + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + +def read_h5(file_name, _MRIDOWN='None', use_kspace=False): + hf = h5py.File(file_name) + volume_kspace = hf['kspace'][()] + slice_kspace = volume_kspace + slice_kspace2 = to_tensor(slice_kspace) + + slice_image = ifft2c(slice_kspace2) + slice_image_abs = complex_abs(slice_image) + slice_image_rss = rss(slice_image_abs, dim=1) + slice_image_rss = np.abs(slice_image_rss.numpy()) + slice_image_rss = normal(slice_image_rss) + + if _MRIDOWN == 'None' or use_kspace: + masked_image_rss = slice_image_rss + + else: + # print("Undersample MRI") + # Undersample MRI + masked_kspace, mask = undersample_mri(slice_kspace2, _MRIDOWN) # Masked + + masked_image = ifft2c(masked_kspace) + masked_image_abs = complex_abs(masked_image) + masked_image_rss = rss(masked_image_abs, dim=1) + masked_image_rss = np.abs(masked_image_rss.numpy()) + masked_image_rss = normal(masked_image_rss) + + return slice_image_rss, masked_image_rss + + +DEBUG = True + +class M4Raw_TrainSet(Dataset): + def __init__(self, root_path, MRIDOWN, kspace_refine='False', use_kspace=False, + save_h5=True, h5_path=""): + + self.use_kspace = use_kspace + self.kspace_refine = kspace_refine + start_time = time.time() + + input_list1 = sorted(glob(os.path.join(root_path + '/multicoil_train' + '/*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(root_path + '/multicoil_train' +'/*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in 
input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T2_input_list = [input_list1, input_list2, input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 256, 256]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 256, 256]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 256, 256]) + + """ + 读取kspace network重建的图像 + """ + if kspace_refine == 'True': + krecon_list1 = sorted(glob(os.path.join(root_path + 'multicoil_train' + '/*_T102_recon_kspace_round2_images.npy'))) + krecon_list2 = [path.replace('_T102','_T101') for path in krecon_list1] + krecon_list3 = [path.replace('_T102','_T103') for path in krecon_list1] + T1_krecon_list = [krecon_list1, krecon_list2, krecon_list3] + + krecon_list1 = sorted(glob(os.path.join(root_path + 'multicoil_train' + '/*_T202_recon_kspace_round2_images.npy'))) + krecon_list2 = [path.replace('_T202','_T201') for path in krecon_list1] + krecon_list3 = [path.replace('_T202','_T203') for path in krecon_list1] + T2_krecon_list = [krecon_list1, krecon_list2, krecon_list3] + + self.T1_krecon_list = T1_krecon_list + self.T2_krecon_list = T2_krecon_list + + self.T1_krecon = np.zeros([len(input_list1), len(T1_krecon_list), 18, 240, 240]).astype(np.float32) + self.T2_krecon = np.zeros([len(input_list2), len(T2_krecon_list), 18, 240, 240]).astype(np.float32) + + + + print('TrainSet loading...') + if save_h5 and not os.path.exists(h5_path): + for i in tqdm(range(len(self.T1_input_list))): + for j, path in enumerate(T1_input_list[i]): + self.T1_images[j][i], _ = read_h5(path, use_kspace=use_kspace) + # self.fname_slices[i].append(path) # each coil + + if kspace_refine == 'True': + for k, path in enumerate(T1_krecon_list[i]): + self.T1_krecon[k][i] = np.load(path).astype(np.float32)/255.0 + self.T1_labels = np.mean(self.T1_images, axis=1) # multi-coil mean + + for i in tqdm(range(len(self.T2_input_list))): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i] = read_h5(path, _MRIDOWN=MRIDOWN, use_kspace=use_kspace) + if kspace_refine == 'True': + for k, path in enumerate(T2_krecon_list[i]): + self.T2_krecon[k][i] = np.load(path).astype(np.float32)/255.0 + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print(f'Finish loading with time = {time.time() - start_time}s') + + + N, _, S, H, W = self.T1_images.shape + self.fname_slices = [] + + for i in range(N): + for j in range(S): + self.fname_slices.append((i, j)) + # print(f'nan value at {i}, {j}, {k}, {l}') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),256,256)[:, :, 8:248, 8:248] + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),256,256)[:, :, 8:248, 8:248] + self.T1_labels = self.T1_labels.reshape(-1,1,256,256)[:, :, 8:248, 8:248] + self.T2_labels = self.T2_labels.reshape(-1,1,256,256)[:, :, 8:248, 8:248] + + if kspace_refine == 'True': + self.T1_krecon = self.T1_krecon.transpose(0,2,1,3,4).reshape(-1,len(T1_krecon_list),240,240) + self.T2_krecon = self.T2_krecon.transpose(0,2,1,3,4).reshape(-1,len(T2_krecon_list),240,240) + + if save_h5: + with h5py.File(h5_path, 'w') as f: + f.create_dataset('T1_images', data=self.T1_images) + f.create_dataset('T2_images', 
data=self.T2_images) + f.create_dataset('T1_labels', data=self.T1_labels) + f.create_dataset('T2_labels', data=self.T2_labels) + f.create_dataset('fname_slices', data=np.array(self.fname_slices, dtype=np.int32)) # , dtype='S' + + if kspace_refine == 'True': + f.create_dataset('T1_krecon', data=self.T1_krecon) + f.create_dataset('T2_krecon', data=self.T2_krecon) + + print("saved h5 file to", h5_path) + + else: + # read it back + with h5py.File(h5_path, 'r') as f: + self.T1_images = f['T1_images'][()] + self.T2_images = f['T2_images'][()] + self.T1_labels = f['T1_labels'][()] + self.T2_labels = f['T2_labels'][()] + self.fname_slices = f['fname_slices'][()] + + if kspace_refine == 'True': + self.T1_krecon = f['T1_krecon'][()] + self.T2_krecon = f['T2_krecon'][()] + + print("read h5 file from", h5_path) + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + T1_images = self.T1_images[idx] # lq_mri + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] # gt_mri + T2_labels = self.T2_labels[idx] + + fname = self.fname_slices[idx][0] + slice = self.fname_slices[idx][1] + + ## 每次都是从三个repetition中选择一个作为input. + choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) + T1_images = T1_images[choices] + T2_images = T2_images[choices] + + t1_kspace_in, t1_in, t1_kspace, t1_img = mri_fft_m4raw(T1_images, T1_labels) + t2_kspace_in, t2_in, t2_kspace, t2_img = mri_fft_m4raw(T2_images, T2_labels) + + + # normalize + t1_img, t1_mean, t1_std = normalize_instance(t1_img) + t1_in = normalize(t1_in, t1_mean, t1_std) + + t2_img, t2_mean, t2_std = normalize_instance(t2_img) + t2_in = normalize(t2_in, t2_mean, t2_std) + + + + # filter value that greater or less than 6 + v = 10 + t1_img = torch.clamp(t1_img, -v, v) + t2_img = torch.clamp(t2_img, -v, v) + t1_in = torch.clamp(t1_in, -v, v) + t2_in = torch.clamp(t2_in, -v, v) + + + # How to get mean and std of the training data? 
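# --- Illustrative sketch (editorial, standalone): the per-slice mean/std kept
# in the sample dict below can undo the normalization applied above, mapping a
# network output back to the original intensity scale before computing metrics.
import torch

def unnormalize(data, mean, std, eps=0.0):
    # inverse of normalize(): data * (std + eps) + mean
    return data * (std + eps) + mean

x = torch.rand(1, 240, 240) * 3.0
mean, std = x.mean(), x.std()
x_back = unnormalize((x - mean) / (std + 1e-11), mean, std, eps=1e-11)
print(torch.allclose(x, x_back, atol=1e-5))  # expected: True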
+ # fname, slice + sample = { + 'fname': fname, + 'slice': slice, + + 'ref_kspace_full': t1_kspace, + 'ref_kspace_sub': t1_kspace_in, + 'ref_image_full': t1_img, + 'ref_image_sub': t1_in, + 't1_mean': t1_mean, + 't1_std': t1_std, + + 'tag_kspace_full': t2_kspace, + 'tag_kspace_sub': t2_kspace_in, + 'tag_image_full': t2_img, + 'tag_image_sub': t2_in, + 't2_mean': t2_mean, + 't2_std': t2_std, + + } + + return sample + + + +class M4Raw_TestSet(Dataset): + def __init__(self, root_path, MRIDOWN, kspace_refine='False', use_kspace=False, + save_h5=True, h5_path=""): + + self.kspace_refine = kspace_refine + + input_list1 = sorted(glob(os.path.join(root_path + '/multicoil_val' + '/*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(root_path + '/multicoil_val' + '/*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T2_input_list = [input_list1,input_list2,input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 256, 256]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 256, 256]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 256, 256]) + + + """ + 读取kspace network重建的图像 + """ + if kspace_refine == 'True': + krecon_list1 = sorted(glob(os.path.join(root_path + 'multicoil_val' + '/*_T102_recon_kspace_round2_images.npy'))) + krecon_list2 = [path.replace('_T102','_T101') for path in krecon_list1] + krecon_list3 = [path.replace('_T102','_T103') for path in krecon_list1] + T1_krecon_list = [krecon_list1, krecon_list2, krecon_list3] + + krecon_list1 = sorted(glob(os.path.join(root_path + 'multicoil_val' + '/*_T202_recon_kspace_round2_images.npy'))) + krecon_list2 = [path.replace('_T202','_T201') for path in krecon_list1] + krecon_list3 = [path.replace('_T202','_T203') for path in krecon_list1] + T2_krecon_list = [krecon_list1, krecon_list2, krecon_list3] + + self.T1_krecon_list = T1_krecon_list + self.T2_krecon_list = T2_krecon_list + + self.T1_krecon = np.zeros([len(input_list1), len(T1_krecon_list), 18, 240, 240]).astype(np.float32) + self.T2_krecon = np.zeros([len(input_list2), len(T2_krecon_list), 18, 240, 240]).astype(np.float32) + + + print('TestSet loading...') + if save_h5 and not os.path.exists(h5_path): + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + self.T1_images[j][i], _ = read_h5(path, use_kspace=use_kspace) + + if kspace_refine == 'True': + for k, path in enumerate(T1_krecon_list[i]): + self.T1_krecon[k][i] = np.load(path).astype(np.float32)/255.0 + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i] = read_h5(path, _MRIDOWN = MRIDOWN, use_kspace=use_kspace) + + if kspace_refine == 'True': + for k, path in enumerate(T2_krecon_list[i]): + self.T2_krecon[k][i] = np.load(path).astype(np.float32)/255.0 + 
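# --- Illustrative sketch (editorial; file names below are hypothetical): the
# loaders above pair the three M4Raw repetitions of a contrast by rewriting the
# suffix of the *_T102 / *_T202 listing, e.g. _T102 -> _T101 / _T103.
rep2 = ["subjA_T102.h5", "subjB_T102.h5"]          # hypothetical names
rep1 = [p.replace("_T102.h5", "_T101.h5") for p in rep2]
rep3 = [p.replace("_T102.h5", "_T103.h5") for p in rep2]
print(rep1, rep3)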
self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + N, _, S, H, W = self.T1_images.shape + self.fname_slices = [] + + for i in range(N): + for j in range(S): + self.fname_slices.append((i, j)) + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),256,256)[:, :, 8:248, 8:248] + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),256,256)[:, :, 8:248, 8:248] + self.T1_labels = self.T1_labels.reshape(-1,1,256,256)[:, :, 8:248, 8:248] + self.T2_labels = self.T2_labels.reshape(-1,1,256,256)[:, :, 8:248, 8:248] + print("Test data shape:", self.T1_images.shape) + + + if save_h5: + with h5py.File(h5_path, 'w') as f: + f.create_dataset('T1_images', data=self.T1_images) + f.create_dataset('T2_images', data=self.T2_images) + f.create_dataset('T1_labels', data=self.T1_labels) + f.create_dataset('T2_labels', data=self.T2_labels) + f.create_dataset('fname_slices', data=np.array(self.fname_slices, dtype=np.int32)) # , dtype='S' + + if kspace_refine == 'True': + f.create_dataset('T1_krecon', data=self.T1_krecon) + f.create_dataset('T2_krecon', data=self.T2_krecon) + + + else: + # read it back + with h5py.File(h5_path, 'r') as f: + self.T1_images = f['T1_images'][()] + self.T2_images = f['T2_images'][()] + self.T1_labels = f['T1_labels'][()] + self.T2_labels = f['T2_labels'][()] + self.fname_slices = f['fname_slices'][()] + + if kspace_refine == 'True': + self.T1_krecon = f['T1_krecon'][()] + self.T2_krecon = f['T2_krecon'][()] + + # if kspace_refine == 'True': + # self.T1_krecon = self.T1_krecon.transpose(0,2,1,3,4).reshape(-1,len(T1_krecon_list),240,240) + # self.T2_krecon = self.T2_krecon.transpose(0,2,1,3,4).reshape(-1,len(T2_krecon_list),240,240) + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + + # print("T1_labels:", T1_labels.shape, T1_labels.dtype, T1_labels.max(), T1_labels.min()) + # print("T2_labels:", T2_labels.shape, T2_labels.dtype, T2_labels.max(), T2_labels.min()) + + choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + + t1_kspace_in, t1_in, t1_kspace, t1_img = mri_fft_m4raw(T1_images, T1_labels) + t2_kspace_in, t2_in, t2_kspace, t2_img = mri_fft_m4raw(T2_images, T2_labels) + + fname = self.fname_slices[idx][0] + slice = self.fname_slices[idx][1] + + # normalize + t1_img, t1_mean, t1_std = normalize_instance(t1_img) + t1_in = normalize(t1_in, t1_mean, t1_std) + + t2_img, t2_mean, t2_std = normalize_instance(t2_img) + t2_in = normalize(t2_in, t2_mean, t2_std) + + # filter value that greater or less than 6 + v = 10 + t1_img = torch.clamp(t1_img, -v, v) + t2_img = torch.clamp(t2_img, -v, v) + t1_in = torch.clamp(t1_in, -v, v) + t2_in = torch.clamp(t2_in, -v, v) + + + + # fname, slice + sample = { + 'fname': fname, + 'slice': slice, + + 'ref_kspace_full': t1_kspace, + 'ref_kspace_sub': t1_kspace_in, + 'ref_image_full': t1_img, + 'ref_image_sub': t1_in, + 't1_mean': t1_mean, + 't1_std': t1_std, + + 'tag_kspace_full': t2_kspace, + 'tag_kspace_sub': t2_kspace_in, + 'tag_image_full': t2_img, + 'tag_image_sub': t2_in, + 't2_mean': t2_mean, + 't2_std': t2_std, + + } + + return sample + + +def compute_metrics(image, labels): + MSE = mean_squared_error(labels, image)/np.var(labels) + PSNR = 
peak_signal_noise_ratio(labels, image) + SSIM = structural_similarity(labels, image) + + return MSE, PSNR, SSIM + + +def complex_abs_eval(data): + return (data[0:1, :, :] ** 2 + data[1:2, :, :] ** 2).sqrt() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--root_path', type=str, default='/data/qic99/MRI_recon/') + parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') + parser.add_argument('--kspace_refine', type=str, default='False', help='whether use the image reconstructed from kspace network.') + args = parser.parse_args() + + db_test = M4Raw_TestSet(args) + testloader = DataLoader(db_test, batch_size=4, shuffle=False, num_workers=4, pin_memory=True) + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + save_dir = "./visualize_images/" + + for i_batch, sampled_batch in enumerate(testloader): + # t1_in, t2_in = sampled_batch['t1_in'].cuda(), sampled_batch['t2_in'].cuda() + # t1, t2 = sampled_batch['t1_labels'].cuda(), sampled_batch['t2_labels'].cuda() + + t1_in, t2_in = sampled_batch['ref_image_sub'].cuda(), sampled_batch['tag_image_sub'].cuda() + t1, t2 = sampled_batch['ref_image_full'].cuda(), sampled_batch['tag_image_full'].cuda() + + # breakpoint() + for j in range(t1_in.shape[0]): + # t1_in_img = (np.clip(complex_abs_eval(t1_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(complex_abs_eval(t1[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(complex_abs_eval(t2_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(complex_abs_eval(t2[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # breakpoint() + t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + # t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + + # print(t1_in_img.shape, t1_img.shape) + t1_MSE, t1_PSNR, t1_SSIM = compute_metrics(t1_in_img, t1_img) + t2_MSE, t2_PSNR, t2_SSIM = compute_metrics(t2_in_img, t2_img) + + + t1_MSE_all.append(t1_MSE) + t1_PSNR_all.append(t1_PSNR) + t1_SSIM_all.append(t1_SSIM) + + t2_MSE_all.append(t2_MSE) + t2_PSNR_all.append(t2_PSNR) + t2_SSIM_all.append(t2_SSIM) + + + # print("t1_PSNR:", t1_PSNR_all) + print("t1_PSNR:", round(np.array(t1_PSNR_all).mean(), 4)) + print("t1_NMSE:", round(np.array(t1_MSE_all).mean(), 4)) + print("t1_SSIM:", round(np.array(t1_SSIM_all).mean(), 4)) + + print("t2_PSNR:", round(np.array(t2_PSNR_all).mean(), 4)) + print("t2_NMSE:", round(np.array(t2_MSE_all).mean(), 4)) + print("t2_SSIM:", round(np.array(t2_SSIM_all).mean(), 4)) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/m4raw_std_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/m4raw_std_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..5b601d4590bc21d45014d37f935df17163f75822 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/m4raw_std_dataloader.py @@ -0,0 +1,663 @@ + 
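# --- Illustrative sketch (editorial, standalone; uses scikit-image like the
# compute_metrics helper above): NMSE / PSNR / SSIM between a synthetic slice
# and a noisy copy. data_range is passed explicitly for float images.
import numpy as np
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity

rng = np.random.RandomState(0)
gt = rng.rand(240, 240).astype(np.float32)
recon = np.clip(gt + 0.05 * rng.randn(240, 240).astype(np.float32), 0, 1)

nmse = mean_squared_error(gt, recon) / np.var(gt)
psnr = peak_signal_noise_ratio(gt, recon, data_range=1.0)
ssim = structural_similarity(gt, recon, data_range=1.0)
print(round(float(nmse), 4), round(float(psnr), 2), round(float(ssim), 4))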
+from __future__ import print_function, division +from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union +import sys +sys.path.append('.') +from glob import glob +import os +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import numpy as np +import torch +from torch.utils.data import Dataset + +import h5py +from matplotlib import pyplot as plt +from dataloaders.m4_utils import ifft2c, fft2c, complex_abs +from dataloaders.kspace_subsample import create_mask_for_mask_type + +from .albu_transform import get_albu_transforms + +import argparse, time +from torch.utils.data import DataLoader +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + +def normalize_instance_dim(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean(dim=(1, 2, 3), keepdim=True) # B, C, H, W + std = data.std(dim=(1, 2, 3), keepdim=True) + + return normalize(data, mean, std, eps), mean, std + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def undersample_mri(kspace, _MRIDOWN): + # print("kspace shape:", kspace.shape) ## [18, 4, 256, 256, 2] + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + elif _MRIDOWN == "12X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.03, 12 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 f| ------------------------------------------------------------------------------------------------------------------------------- + + shape = [256, 256, 1] + mask = ff(shape, seed=1337) ## [1, 256, 1] + + mask = mask[:, :, 0] # [1, 256] + + masked_kspace = kspace * mask[None, None, :, :, None] + + return masked_kspace, mask.unsqueeze(-1) + + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. 
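# --- Illustrative sketch (editorial, standalone): per-sample normalization as
# in normalize_instance_dim above; keepdim=True lets the per-sample mean/std
# broadcast over a (B, C, H, W) batch.
import torch

x = torch.randn(2, 1, 240, 240) * 3 + 5
mean = x.mean(dim=(1, 2, 3), keepdim=True)
std = x.std(dim=(1, 2, 3), keepdim=True)
x_norm = (x - mean) / (std + 1e-11)
print(x_norm.mean(dim=(1, 2, 3)), x_norm.std(dim=(1, 2, 3)))   # ~0 and ~1 per sample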
+ + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. 
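# --- Illustrative sketch (editorial, standalone): center-cropping a 256x256
# complex volume to 240x240, equivalent to the [..., 8:248, 8:248] slicing used
# by the other dataloaders above.
import torch

data = torch.randn(18, 4, 256, 256, 2)        # slices, coils, H, W, (re, im)
h_from = (data.shape[-3] - 240) // 2          # 8
w_from = (data.shape[-2] - 240) // 2          # 8
cropped = data[..., h_from:h_from + 240, w_from:w_from + 240, :]
print(cropped.shape)                          # torch.Size([18, 4, 240, 240, 2])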
+ """ + return torch.sqrt((data ** 2).sum(dim)) + +import imageio as io +def read_h5(file_name, _MRIDOWN, use_kspace, crop_size=[240,240]): + + + hf = h5py.File(file_name) + volume_kspace = hf['kspace'][()] + + slice_kspace = volume_kspace + slice_kspace = to_tensor(slice_kspace) + + target = ifft2c(slice_kspace) + target = complex_center_crop(target, crop_size) + target = complex_abs(target) + target = rss(target, dim=1) + + if not use_kspace: + masked_kspace, mask = undersample_mri(slice_kspace, _MRIDOWN) + lq_image = ifft2c(masked_kspace) + lq_image = complex_center_crop(lq_image, crop_size) + lq_image = complex_abs(lq_image) + lq_image = rss(lq_image, dim=1) + + else: + lq_image = target + + lq_image_list=[] + mean_list=[] + std_list=[] + v = 10 + for i in range(lq_image.shape[0]): + image, mean, std = normalize_instance(lq_image[i], eps=1e-11) + image = image.clamp(-v, v) + lq_image_list.append(image) + mean_list.append(mean) + std_list.append(std) + + + + target_list=[] + + for i in range(lq_image.shape[0]): + target_slice = normalize(target[i], mean_list[i], std_list[i], eps=1e-11) + target_slice = target_slice.clamp(-v, v) + target_list.append(target_slice) + + return torch.stack(lq_image_list), torch.stack(target_list), torch.stack(mean_list), torch.stack(std_list) + + + +import albumentations as A + +# Define a shift-only transform +shift_transform = transform = A.Compose([ + + A.ShiftScaleRotate(shift_limit=0.2, scale_limit=(-0.2, 0.2), + rotate_limit=5, p=0.5), + + A.OneOf([ + A.GridDistortion(num_steps=1, distort_limit=0.3, p=1.0), + A.ElasticTransform(alpha=2, sigma=5, p=1.0) + ], p=0.5), + +], additional_targets={ + 'image2': 'image', + 'image3': 'image', + 'image4': 'image' + } +) + + + +def chw_to_hwc(img): return np.moveaxis(img, 0, -1) +def hwc_to_chw(img): return np.moveaxis(img, -1, 0) +def to_batch(img): return np.expand_dims(img, 0) + + + +class M4Raw_TrainSet(Dataset): + def __init__(self, args, use_kspace=False, DEBUG=False, + h5_path=""): + + mask_func = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + start_time = time.time() + + self.albu_transforms = get_albu_transforms("train", (240, 240)) + + self._MRIDOWN = args.MRIDOWN + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + + T2_input_list = [input_list1, input_list2, input_list3] + + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = 
np.zeros([len(input_list2),len(T2_input_list), 18]) + + print('TrainSet loading...') + if not os.path.exists(h5_path): + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN, use_kspace=use_kspace) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_masked_images[j][i], self.T2_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, self._MRIDOWN, use_kspace=use_kspace) + # lq_image_list, target_list + + + # self.T2_labels = np.mean(self.T2_images, axis=1) # TODO + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print(f'Finish loading with time = {time.time() - start_time}s') + + # print("T1 image original shape:", self.T1_images.shape) # T1 image original shape: (128, 3, 18, 256, 256) + # print("T2 image original shape:", self.T2_images.shape) + + N, _, S, H, W = self.T1_images.shape + self.fname_slices = [] + + for i in range(N): + for j in range(S): + self.fname_slices.append((i, j)) + + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + + print("Train data shape:", self.T1_images.shape) + + + with h5py.File(h5_path, 'w') as f: + f.create_dataset('T1_images', data=self.T1_images) + f.create_dataset('T2_images', data=self.T2_images) + f.create_dataset('T1_labels', data=self.T1_labels) + f.create_dataset('T2_labels', data=self.T2_labels) + f.create_dataset('T2_mean', data=self.T2_mean) + f.create_dataset('T2_std', data=self.T2_std) + + + f.create_dataset('fname_slices', data=np.array(self.fname_slices, dtype=np.int32)) # , dtype='S' + + + print("saved h5 file to", h5_path) + + else: + # read it back + with h5py.File(h5_path, 'r') as f: + self.T1_images = f['T1_images'][()] + self.T2_images = f['T2_images'][()] + self.T1_labels = f['T1_labels'][()] + self.T2_labels = f['T2_labels'][()] + self.fname_slices = f['fname_slices'][()] + + self.T2_mean = f['T2_mean'][()] + self.T2_std = f['T2_std'][()] + + print("read h5 file from", h5_path) + + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + + fname = self.fname_slices[idx][0] + slice = self.fname_slices[idx][1] + + + choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. 
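# --- Illustrative sketch (editorial, standalone; albumentations assumed
# installed): additional_targets applies one shared geometric augmentation to
# the undersampled input and its label, as shift_transform above does for the
# T1/T2 image pairs in __getitem__.
import numpy as np
import albumentations as A

paired = A.Compose(
    [A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=5, p=1.0)],
    additional_targets={"image2": "image"},
)

lq = np.random.rand(240, 240).astype(np.float32)
gt = np.random.rand(240, 240).astype(np.float32)
out = paired(image=lq, image2=gt)
print(out["image"].shape, out["image2"].shape)   # both (240, 240)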
+ # choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + + # (1, 240, 240) + sample = self.albu_transforms(image=T1_images[0], image2=T2_images[0], + image3=T1_labels[0], image4=T2_labels[0]) + + # breakpoint() + + + transformed = shift_transform( + image=sample['image'], + image2=sample['image2'], + image3=sample['image3'], + image4=sample['image4'] + ) + + t1_in = np.expand_dims(transformed['image'], 0) + t2_in = np.expand_dims(transformed['image2'], 0) + t1 = np.expand_dims(transformed['image3'], 0) + t2 = np.expand_dims(transformed['image4'], 0) + + + + sample = { + 'fname': fname, + 'slice': slice, + + 't1_in': t1_in.astype(np.float32), + 't1': t1.astype(np.float32), + "t2_mean": t2_mean, "t2_std": t2_std, + + 't2_in': t2_in.astype(np.float32), + 't2': t2.astype(np.float32)} + + return sample #, sample_stats + + + +class M4Raw_TestSet(Dataset): + def __init__(self, args, use_kspace=False, DEBUG=False, h5_path=""): + + mask_func = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + self.use_kspace = use_kspace + self._MRIDOWN = args.MRIDOWN + + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val' + '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + + T2_input_list = [input_list1,input_list2,input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + + + print('TestSet loading...') + if not os.path.exists(h5_path): + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN, use_kspace=use_kspace) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_masked_images[j][i], self.T2_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, + self._MRIDOWN, + use_kspace=use_kspace) + # lq_image_list, target_list + + # self.T2_labels = np.mean(self.T2_images, axis=1) # TODO + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + N, _, S, H, W = self.T1_images.shape + self.fname_slices = [] + + for i in range(N): + for j in range(S): + self.fname_slices.append((i, j)) + + + self.T1_images = 
self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + + print("Train data shape:", self.T1_images.shape) + print("try saving to :", h5_path) + + with h5py.File(h5_path, 'w') as f: + f.create_dataset('T1_images', data=self.T1_images) + f.create_dataset('T2_images', data=self.T2_images) + f.create_dataset('T1_labels', data=self.T1_labels) + f.create_dataset('T2_labels', data=self.T2_labels) + f.create_dataset('T2_mean', data=self.T2_mean) + f.create_dataset('T2_std', data=self.T2_std) + + f.create_dataset('fname_slices', data=np.array(self.fname_slices, dtype=np.int32)) + print("saved h5 file to", h5_path) + + else: + with h5py.File(h5_path, 'r') as f: + self.T1_images = f['T1_images'][:] + self.T2_images = f['T2_images'][:] + self.T1_labels = f['T1_labels'][:] + self.T2_labels = f['T2_labels'][:] + self.T2_mean = f['T2_mean'][:] + self.T2_std = f['T2_std'][:] + self.fname_slices = f['fname_slices'][:] + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + + + fname = self.fname_slices[idx][0] + slice = self.fname_slices[idx][1] + + + + # choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. + choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + + + t1_in = T1_images + t1 = T1_labels + t2_in = T2_images + t2 = T2_labels + + # sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + # print("Test t1_in shape:", t1_in.shape, "t1 shape:", t1.shape, "t2_in shape:", t2_in.shape, "t2 shape:", t2.shape) + + # breakpoint() + sample = { + 'fname': fname, + 'slice': slice, + + 't1_in': t1_in.astype(np.float32), + 't1': t1.astype(np.float32), + "t2_mean": t2_mean, "t2_std": t2_std, + + 't2_in': t2_in.astype(np.float32), + 't2': t2.astype(np.float32)} + + return sample #, sample_stats + + + +def compute_metrics(image, labels): + MSE = mean_squared_error(labels, image)/np.var(labels) + PSNR = peak_signal_noise_ratio(labels, image) + SSIM = structural_similarity(labels, image) + + # print("metrics:", MSE, PSNR, SSIM) + + return MSE, PSNR, SSIM + +def complex_abs_eval(data): + return (data[0:1, :, :] ** 2 + data[1:2, :, :] ** 2).sqrt() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--root_path', type=str, default='/data/qic99/MRI_recon/') + parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') + parser.add_argument('--kspace_refine', type=str, default='False', help='whether use the image reconstructed from kspace network.') + args = parser.parse_args() + + db_test = M4Raw_TestSet(args) + testloader = DataLoader(db_test, batch_size=4, shuffle=False, num_workers=4, pin_memory=True) + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + save_dir = "./visualize_images/" + + for i_batch, sampled_batch in enumerate(testloader): + # t1_in, t2_in = sampled_batch['t1_in'].cuda(), 
sampled_batch['t2_in'].cuda() + # t1, t2 = sampled_batch['t1_labels'].cuda(), sampled_batch['t2_labels'].cuda() + + t1_in, t2_in = sampled_batch['ref_image_sub'].cuda(), sampled_batch['tag_image_sub'].cuda() + t1, t2 = sampled_batch['ref_image_full'].cuda(), sampled_batch['tag_image_full'].cuda() + + # breakpoint() + for j in range(t1_in.shape[0]): + # t1_in_img = (np.clip(complex_abs_eval(t1_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(complex_abs_eval(t1[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(complex_abs_eval(t2_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(complex_abs_eval(t2[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # breakpoint() + t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + # t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + + # print(t1_in_img.shape, t1_img.shape) + t1_MSE, t1_PSNR, t1_SSIM = compute_metrics(t1_in_img, t1_img) + t2_MSE, t2_PSNR, t2_SSIM = compute_metrics(t2_in_img, t2_img) + + + t1_MSE_all.append(t1_MSE) + t1_PSNR_all.append(t1_PSNR) + t1_SSIM_all.append(t1_SSIM) + + t2_MSE_all.append(t2_MSE) + t2_PSNR_all.append(t2_PSNR) + t2_SSIM_all.append(t2_SSIM) + + + # print("t1_PSNR:", t1_PSNR_all) + print("t1_PSNR:", round(np.array(t1_PSNR_all).mean(), 4)) + print("t1_NMSE:", round(np.array(t1_MSE_all).mean(), 4)) + print("t1_SSIM:", round(np.array(t1_SSIM_all).mean(), 4)) + + print("t2_PSNR:", round(np.array(t2_PSNR_all).mean(), 4)) + print("t2_NMSE:", round(np.array(t2_MSE_all).mean(), 4)) + print("t2_SSIM:", round(np.array(t2_SSIM_all).mean(), 4)) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/math.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/math.py new file mode 100644 index 0000000000000000000000000000000000000000..120b9f0501b1ef187228e2650413d675b307d1cb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/math.py @@ -0,0 +1,231 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. + + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. 
+ + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. 
+ """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/new_m4raw_std_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/new_m4raw_std_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..facf0d7c3d2febdf16ee1b1ec7f084e18cea36e2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/new_m4raw_std_dataloader.py @@ -0,0 +1,630 @@ + +from __future__ import print_function, division +from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union +import sys +sys.path.append('.') +from glob import glob +import os +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import numpy as np +import torch +from torch.utils.data import Dataset + +import h5py +from matplotlib import pyplot as plt +from dataloaders.m4_utils import ifft2c, fft2c, complex_abs +from dataloaders.kspace_subsample import create_mask_for_mask_type + +from .albu_transform import get_albu_transforms + +import argparse, time +from torch.utils.data import DataLoader +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + +def normalize_instance_dim(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean(dim=(1, 2, 3), keepdim=True) # B, C, H, W + std = data.std(dim=(1, 2, 3), keepdim=True) + + return normalize(data, mean, std, eps), mean, std + + +def normalize_instance(data, eps=1e-6): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. 
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def undersample_mri(kspace, _MRIDOWN): + # print("kspace shape:", kspace.shape) ## [18, 4, 256, 256, 2] + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + elif _MRIDOWN == "12X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.03, 12 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 f| ------------------------------------------------------------------------------------------------------------------------------- + + shape = [256, 256, 1] + mask = ff(shape, seed=1337) ## [1, 256, 1] + + mask = mask[:, :, 0] # [1, 256] + + masked_kspace = kspace * mask[None, None, :, :, None] + + return masked_kspace, mask.unsqueeze(-1) + + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. 
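# --- Illustrative sketch (editorial, standalone): root-sum-of-squares coil
# combination as implemented by rss above, reducing a (coil, H, W) magnitude
# stack to a single (H, W) image.
import torch

coil_mag = torch.rand(4, 240, 240)              # per-coil magnitude images
combined = torch.sqrt((coil_mag ** 2).sum(dim=0))
print(combined.shape)                            # torch.Size([240, 240])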
+ + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + +import imageio as io +def read_h5(file_name, _MRIDOWN, use_kspace, crop_size=[240,240]): + + + hf = h5py.File(file_name) + volume_kspace = hf['kspace'][()] + + slice_kspace = volume_kspace + slice_kspace = to_tensor(slice_kspace) + + target = ifft2c(slice_kspace) + target = complex_center_crop(target, crop_size) + target = complex_abs(target) + target = rss(target, dim=1) + + if not use_kspace: + masked_kspace, mask = undersample_mri(slice_kspace, _MRIDOWN) + lq_image = ifft2c(masked_kspace) + lq_image = complex_center_crop(lq_image, crop_size) + lq_image = complex_abs(lq_image) + lq_image = rss(lq_image, dim=1) + + else: + lq_image = target + + lq_image_list=[] + mean_list=[] + std_list=[] + v = 10 + for i in range(lq_image.shape[0]): + image, mean, std = normalize_instance(lq_image[i], eps=1e-11) + image = image.clamp(-v, v) + lq_image_list.append(image) + mean_list.append(mean) + std_list.append(std) + + + + target_list=[] + + for i in range(lq_image.shape[0]): + target_slice = normalize(target[i], mean_list[i], std_list[i], eps=1e-11) + target_slice = target_slice.clamp(-v, v) + target_list.append(target_slice) + + return torch.stack(lq_image_list), torch.stack(target_list), torch.stack(mean_list), torch.stack(std_list) + + + + + +class M4Raw_TrainSet(Dataset): + def __init__(self, args, use_kspace=False, DEBUG=False, + h5_path=""): + + mask_func = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + start_time = time.time() + + self.albu_transforms = get_albu_transforms("train", (args.image_size, args.image_size), + shift_limit=0.01, + scale_limit=(-0.01, 0.01), + rotate_limit=1, + distort_limit=0.01, # + elastic_alpha=1, elastic_sigma=1 + ) + + self._MRIDOWN = args.MRIDOWN + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + if DEBUG: + input_list1 = input_list1[:2] + input_list2 = input_list2[:2] + input_list3 = input_list3[:2] + + + T2_input_list = [input_list1, input_list2, input_list3] + + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + + print('TrainSet loading...') + if not os.path.exists(h5_path): + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN, use_kspace=use_kspace) + + 
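+            # read_h5 returns (undersampled stack, fully-sampled target stack, means, stds);
+            # for T1 only the fully-sampled target is kept, and the label below is its
+            # average over the three repetitions (axis=1).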
+            self.T1_labels = np.mean(self.T1_images, axis=1)
+
+            for i in range(len(self.T2_input_list)):
+                for j, path in enumerate(T2_input_list[i]):
+                    self.T2_masked_images[j][i], self.T2_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, self._MRIDOWN, use_kspace=use_kspace)
+                    # lq_image_list, target_list
+
+
+            # self.T2_labels = np.mean(self.T2_images, axis=1)  # TODO
+            self.T2_labels = np.mean(self.T2_images, axis=1)
+            self.T2_images = self.T2_masked_images
+
+            print(f'Finish loading with time = {time.time() - start_time}s')
+
+            # print("T1 image original shape:", self.T1_images.shape)  # T1 image original shape: (128, 3, 18, 256, 256)
+            # print("T2 image original shape:", self.T2_images.shape)
+
+            N, _, S, H, W = self.T1_images.shape
+            self.fname_slices = []
+
+            for i in range(N):
+                for j in range(S):
+                    self.fname_slices.append((i, j))
+
+
+            self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240)
+            self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240)
+            self.T1_labels = self.T1_labels.reshape(-1,1,240,240)
+            self.T2_labels = self.T2_labels.reshape(-1,1,240,240)
+            self.T2_mean = self.T2_mean.reshape(-1,3)
+            self.T2_std = self.T2_std.reshape(-1,3)
+
+            print("Train data shape:", self.T1_images.shape)
+
+
+            with h5py.File(h5_path, 'w') as f:
+                f.create_dataset('T1_images', data=self.T1_images)
+                f.create_dataset('T2_images', data=self.T2_images)
+                f.create_dataset('T1_labels', data=self.T1_labels)
+                f.create_dataset('T2_labels', data=self.T2_labels)
+                f.create_dataset('T2_mean', data=self.T2_mean)
+                f.create_dataset('T2_std', data=self.T2_std)
+
+                f.create_dataset('fname_slices', data=np.array(self.fname_slices, dtype=np.int32))  # , dtype='S'
+
+            print("saved h5 file to", h5_path)
+
+        else:
+            # read it back
+            with h5py.File(h5_path, 'r') as f:
+                self.T1_images = f['T1_images'][()]
+                self.T2_images = f['T2_images'][()]
+                self.T1_labels = f['T1_labels'][()]
+                self.T2_labels = f['T2_labels'][()]
+                self.fname_slices = f['fname_slices'][()]
+
+                self.T2_mean = f['T2_mean'][()]
+                self.T2_std = f['T2_std'][()]
+
+            print("read h5 file from", h5_path)
+
+
+    def __len__(self):
+        return len(self.T1_images)
+
+    def __getitem__(self, idx):
+        T1_images = self.T1_images[idx]
+        T2_images = self.T2_images[idx]
+        T1_labels = self.T1_labels[idx]
+        T2_labels = self.T2_labels[idx]
+
+        fname = self.fname_slices[idx][0]
+        slice = self.fname_slices[idx][1]
+
+        choices = np.random.choice([i for i in range(len(self.T1_input_list))], 1)  ## each time, randomly pick one of the three repetitions as the input.
+        # choices = np.random.choice([0], 1)  ## use the first repetition as the input image, for testing
+        T1_images = T1_images[choices]
+        T2_images = T2_images[choices]
+        t2_mean = self.T2_mean[idx][choices]
+        t2_std = self.T2_std[idx][choices]
+
+        # (1, 240, 240)
+        sample = self.albu_transforms(image=T1_images[0], image2=T2_images[0],
+                                      image3=T1_labels[0], image4=T2_labels[0])
+
+        # breakpoint()
+        t1_in = np.expand_dims(sample['image'], 0)
+        t2_in = np.expand_dims(sample['image2'], 0)
+        t1 = np.expand_dims(sample['image3'], 0)
+        t2 = np.expand_dims(sample['image4'], 0)
+
+        # breakpoint()
+        sample = {
+            'fname': fname,
+            'slice': slice,
+
+            't1_in': t1_in.astype(np.float32),
+            't1': t1.astype(np.float32),
+            "t2_mean": t2_mean, "t2_std": t2_std,
+
+            't2_in': t2_in.astype(np.float32),
+            't2': t2.astype(np.float32)}
+
+        return sample  #, sample_stats
+
+
+class M4Raw_TestSet(Dataset):
+    def __init__(self, args, use_kspace=False, DEBUG=False, h5_path=""):
+
+        mask_func = create_mask_for_mask_type(
+            args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS,
+        )
+        self.use_kspace = use_kspace
+        self._MRIDOWN = args.MRIDOWN
+
+        self.input_normalize = args.input_normalize
+        input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val', '*_T102.h5')))
+        input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1]
+        input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1]
+        if DEBUG:
+            input_list1 = input_list1[:2]
+            input_list2 = input_list2[:2]
+            input_list3 = input_list3[:2]
+
+        self.albu_transforms = get_albu_transforms("test", (args.image_size, args.image_size))
+
+        T1_input_list = [input_list1, input_list2, input_list3]
+
+        input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val', '*_T202.h5')))
+        input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1]
+        input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1]
+        if DEBUG:
+            input_list1 = input_list1[:2]
+            input_list2 = input_list2[:2]
+            input_list3 = input_list3[:2]
+
+        T2_input_list = [input_list1, input_list2, input_list3]
+
+        self.T1_input_list = T1_input_list
+        self.T2_input_list = T2_input_list
+        self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240])
+        self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240])
+        self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240])
+        self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18])
+        self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18])
+
+        print('TestSet loading...')
+        if not os.path.exists(h5_path):
+            for i in range(len(self.T1_input_list)):
+                for j, path in enumerate(T1_input_list[i]):
+                    _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN, use_kspace=use_kspace)
+
+            self.T1_labels = np.mean(self.T1_images, axis=1)
+
+            for i in range(len(self.T2_input_list)):
+                for j, path in enumerate(T2_input_list[i]):
+                    self.T2_masked_images[j][i], self.T2_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path,
+                                                                                                                       self._MRIDOWN,
+                                                                                                                       use_kspace=use_kspace)
+                    # lq_image_list, target_list
+
+            # self.T2_labels = np.mean(self.T2_images, axis=1)  # TODO
+            self.T2_labels = np.mean(self.T2_images, axis=1)
+            self.T2_images = self.T2_masked_images
+
+            print('Finish loading')
+            N, _, S, H, W = self.T1_images.shape
+            self.fname_slices = []
+
+            for i in range(N):
+                for j in range(S):
+                    self.fname_slices.append((i, j))
+
+            self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240)
+            self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240)
+            self.T1_labels = self.T1_labels.reshape(-1,1,240,240)
+            self.T2_labels = self.T2_labels.reshape(-1,1,240,240)
+            self.T2_mean = self.T2_mean.reshape(-1,3)
+            self.T2_std = self.T2_std.reshape(-1,3)
+            print("Test data shape:", self.T1_images.shape)
+
+            with h5py.File(h5_path, 'w') as f:
+                f.create_dataset('T1_images', data=self.T1_images)
+                f.create_dataset('T2_images', data=self.T2_images)
+                f.create_dataset('T1_labels', data=self.T1_labels)
+                f.create_dataset('T2_labels', data=self.T2_labels)
+                f.create_dataset('T2_mean', data=self.T2_mean)
+                f.create_dataset('T2_std', data=self.T2_std)
+
+                f.create_dataset('fname_slices', data=np.array(self.fname_slices, dtype=np.int32))
+            print("saved h5 file to", h5_path)
+        else:
+            with h5py.File(h5_path, 'r') as f:
+                self.T1_images = f['T1_images'][:]
+                self.T2_images = f['T2_images'][:]
+                self.T1_labels = f['T1_labels'][:]
+                self.T2_labels = f['T2_labels'][:]
+                self.T2_mean = f['T2_mean'][:]
+                self.T2_std = f['T2_std'][:]
+                self.fname_slices = f['fname_slices'][:]
+
+
+    def __len__(self):
+        return len(self.T1_images)
+
+    def __getitem__(self, idx):
+
+        T1_images = self.T1_images[idx]
+        T2_images = self.T2_images[idx]
+        T1_labels = self.T1_labels[idx]
+        T2_labels = self.T2_labels[idx]
+
+        fname = self.fname_slices[idx][0]
+        slice = self.fname_slices[idx][1]
+
+        # choices = np.random.choice([i for i in range(len(self.T1_input_list))], 1)  ## each time, randomly pick one of the three repetitions as the input.
+        choices = np.random.choice([0], 1)  ## use the first repetition as the input image for testing
+        T1_images = T1_images[choices]
+        T2_images = T2_images[choices]
+        t2_mean = self.T2_mean[idx][choices]
+        t2_std = self.T2_std[idx][choices]
+
+        # (1, 240, 240)
+        sample = self.albu_transforms(image=T1_images[0], image2=T2_images[0],
+                                      image3=T1_labels[0], image4=T2_labels[0])
+
+        # breakpoint()
+        t1_in = np.expand_dims(sample['image'], 0)
+        t2_in = np.expand_dims(sample['image2'], 0)
+        t1 = np.expand_dims(sample['image3'], 0)
+        t2 = np.expand_dims(sample['image4'], 0)
+
+        # breakpoint()
+        sample = {
+            'fname': fname,
+            'slice': slice,
+
+            't1_in': t1_in.astype(np.float32),
+            't1': t1.astype(np.float32),
+            "t2_mean": t2_mean, "t2_std": t2_std,
+
+            't2_in': t2_in.astype(np.float32),
+            't2': t2.astype(np.float32)}
+
+        return sample  #, sample_stats
+
+
+def compute_metrics(image, labels):
+    MSE = mean_squared_error(labels, image) / np.var(labels)
+    PSNR = peak_signal_noise_ratio(labels, image)
+    SSIM = structural_similarity(labels, image)
+
+    # print("metrics:", MSE, PSNR, SSIM)
+
+    return MSE, PSNR, SSIM
+
+def complex_abs_eval(data):
+    return (data[0:1, :, :] ** 2 + data[1:2, :, :] ** 2).sqrt()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--root_path', type=str, default='/data/qic99/MRI_recon/')
+    parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate')
+    parser.add_argument('--kspace_refine', type=str, default='False', help='whether to use the image reconstructed from the kspace network.')
+    args = parser.parse_args()
+
+    db_test = M4Raw_TestSet(args)
+    testloader = DataLoader(db_test, batch_size=4, shuffle=False, num_workers=4, pin_memory=True)
+
+    t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], []
+    t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], []
+
+    save_dir = "./visualize_images/"
+
+    for i_batch, sampled_batch in enumerate(testloader):
+        # t1_in, t2_in = sampled_batch['t1_in'].cuda(), sampled_batch['t2_in'].cuda()
+        # t1, t2 = sampled_batch['t1_labels'].cuda(),
sampled_batch['t2_labels'].cuda() + + t1_in, t2_in = sampled_batch['ref_image_sub'].cuda(), sampled_batch['tag_image_sub'].cuda() + t1, t2 = sampled_batch['ref_image_full'].cuda(), sampled_batch['tag_image_full'].cuda() + + # breakpoint() + for j in range(t1_in.shape[0]): + # t1_in_img = (np.clip(complex_abs_eval(t1_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(complex_abs_eval(t1[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(complex_abs_eval(t2_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(complex_abs_eval(t2[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # breakpoint() + t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + + # print(t1_in_img.shape, t1_img.shape) + t1_MSE, t1_PSNR, t1_SSIM = compute_metrics(t1_in_img, t1_img) + t2_MSE, t2_PSNR, t2_SSIM = compute_metrics(t2_in_img, t2_img) + + + t1_MSE_all.append(t1_MSE) + t1_PSNR_all.append(t1_PSNR) + t1_SSIM_all.append(t1_SSIM) + + t2_MSE_all.append(t2_MSE) + t2_PSNR_all.append(t2_PSNR) + t2_SSIM_all.append(t2_SSIM) + + + # print("t1_PSNR:", t1_PSNR_all) + print("t1_PSNR:", round(np.array(t1_PSNR_all).mean(), 4)) + print("t1_NMSE:", round(np.array(t1_MSE_all).mean(), 4)) + print("t1_SSIM:", round(np.array(t1_SSIM_all).mean(), 4)) + + print("t2_PSNR:", round(np.array(t2_PSNR_all).mean(), 4)) + print("t2_NMSE:", round(np.array(t2_MSE_all).mean(), 4)) + print("t2_SSIM:", round(np.array(t2_SSIM_all).mean(), 4)) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/subsample.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..d0620da3414c6077e4293376fb8a9be01ad19990 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/subsample.py @@ -0,0 +1,195 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. 
+ """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. 
This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/transforms.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d9cecc761fc46e201705992ce6226598492f76af --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/dataloaders/transforms.py @@ -0,0 +1,493 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch + +from .math import ifft2c, fft2c, complex_abs +from .subsample import create_mask_for_mask_type, MaskFunc +import random + +from typing import Dict, Optional, Sequence, Tuple, Union +from matplotlib import pyplot as plt +import os + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. 
+ """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data. + """ + data = data.numpy() + + return data[..., 0] + 1j * data[..., 1] + + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + +def mask_center(x, mask_from, mask_to): + mask = torch.zeros_like(x) + mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to] + + return mask + + +def center_crop(data, shape): + """ + Apply a center crop to the input real image or batch of real images. + + Args: + data (torch.Tensor): The input tensor to be center cropped. It should + have at least 2 dimensions and the cropping is applied along the + last two dimensions. + shape (int, int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image. + """ + assert 0 < shape[0] <= data.shape[-2] + assert 0 < shape[1] <= data.shape[-1] + + w_from = (data.shape[-2] - shape[0]) // 2 + h_from = (data.shape[-1] - shape[1]) // 2 + w_to = w_from + shape[0] + h_to = h_from + shape[1] + + return data[..., w_from:w_to, h_from:h_to] + + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + +def center_crop_to_smallest(x, y): + """ + Apply a center crop on the larger image to the size of the smaller. + + The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at + dim=-1 and y is smaller than x at dim=-2, then the returned dimension will + be a mixture of the two. + + Args: + x (torch.Tensor): The first image. + y (torch.Tensor): The second image + + Returns: + tuple: tuple of tensors x and y, each cropped to the minimim size. 
+ """ + smallest_width = min(x.shape[-1], y.shape[-1]) + smallest_height = min(x.shape[-2], y.shape[-2]) + x = center_crop(x, (smallest_height, smallest_width)) + y = center_crop(y, (smallest_height, smallest_width)) + + return x, y + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +class DataTransform(object): + """ + Data Transformer for training U-Net models. + """ + + def __init__(self, which_challenge): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.which_challenge = which_challenge + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. 
+ """ + kspace = to_tensor(kspace) + + # inverse Fourier transform to get zero filled solution + image = ifft2c(kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + + # getLR + imgfft = fft2c(image) + imgfft = complex_center_crop(imgfft, (160, 160)) + LR_image = ifft2c(imgfft) + + # absolute value + LR_image = complex_abs(LR_image) + + # normalize input + LR_image, mean, std = normalize_instance(LR_image, eps=1e-11) + LR_image = LR_image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return LR_image, target, mean, std, fname, slice_num + +class DenoiseDataTransform(object): + def __init__(self, size, noise_rate): + super(DenoiseDataTransform, self).__init__() + self.size = (size, size) + self.noise_rate = noise_rate + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + max_value = attrs["max"] + + #target + target = to_tensor(target) + target = center_crop(target, self.size) + target, mean, std = normalize_instance(target, eps=1e-11) + target = target.clamp(-6, 6) + + #image + kspace = to_tensor(kspace) + complex_image = ifft2c(kspace) #complex_image + image = complex_center_crop(complex_image, self.size) + noise_image = self.rician_noise(image, max_value) + noise_image = complex_abs(noise_image) + + noise_image = normalize(noise_image, mean, std, eps=1e-11) + noise_image = noise_image.clamp(-6, 6) + + return noise_image, target, mean, std, fname, slice_num + + + def rician_noise(self, X, noise_std): + #Add rician noise with variance sampled uniformly from the range 0 and 0.1 + noise_std = random.uniform(0, noise_std*self.noise_rate) + Ir = X + noise_std * torch.randn(X.shape) + Ii = noise_std*torch.randn(X.shape) + In = torch.sqrt(Ir ** 2 + Ii ** 2) + return In + + +def apply_mask( + data: torch.Tensor, + mask_func: MaskFunc, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Optional[Sequence[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Subsample given k-space by multiplying with a mask. + Args: + data: The input k-space data. This should have at least 3 dimensions, + where dimensions -3 and -2 are the spatial dimensions, and the + final dimension has size 2 (for complex values). + mask_func: A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed: Seed for the random number generator. + padding: Padding value to apply for mask. + Returns: + tuple containing: + masked data: Subsampled k-space data + mask: The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + + +class ReconstructionTransform(object): + """ + Data Transformer for training U-Net models. 
+ """ + + def __init__(self, which_challenge, mask_func=None, use_seed=True): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.mask_func = mask_func + self.which_challenge = which_challenge + self.use_seed = use_seed + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. + """ + kspace = to_tensor(kspace) + + # apply mask + if self.mask_func: + seed = None if not self.use_seed else tuple(map(ord, fname)) + masked_kspace, mask = apply_mask(kspace, self.mask_func, seed) + # print("mask shape", mask.shape, mask.sum()) + # mask shape torch.Size([1, 368, 1]) tensor(89.) 
+ + else: + masked_kspace = kspace + # print("masked_kspace shape", masked_kspace.shape) + + # inverse Fourier transform to get zero filled solution + image = ifft2c(masked_kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + # print('image',image.shape) + # absolute value + image = complex_abs(image) + + # apply Root-Sum-of-Squares if multicoil data + if self.which_challenge == "multicoil": + image = rss(image) + + # normalize input + image, mean, std = normalize_instance(image, eps=1e-11) + image = image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return image, target, mean, std, fname, slice_num + + +def build_transforms(args, mode = 'train', use_kspace=False): + + challenge = 'singlecoil' + if use_kspace: + return ReconstructionTransform(challenge) + + else: + if mode == 'train': + mask = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + return ReconstructionTransform(challenge, mask, use_seed=False) + elif mode == 'val': + mask = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + return ReconstructionTransform(challenge, mask) + else: + return ReconstructionTransform(challenge) + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/documents/INSTALL.md b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/documents/INSTALL.md new file mode 100644 index 0000000000000000000000000000000000000000..9912721cb3354240d99c08838ae8d2b1417b339b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/documents/INSTALL.md @@ -0,0 +1,11 @@ +## Dependency +The code is tested on `python 3.8, Pytorch 1.13`. 
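+
+After finishing the setup below, a quick import check can verify the environment (an optional, minimal snippet; not part of the original instructions):
+
+```python
+import torch, torchvision
+print(torch.__version__, torchvision.__version__)  # expected: 1.13.1+cu116 / 0.14.1+cu116
+print("CUDA available:", torch.cuda.is_available())
+```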
+ +##### Setup environment + +```bash +conda create -n FSMNet python=3.8 +source activate FSMNet # or conda activate FSMNet +pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install einops h5py matplotlib scikit_image tensorboardX yacs pandas opencv-python timm ml_collections +``` diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe1e0f26ca039d666189f901309dbb9adfbadc89 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/__init__.py @@ -0,0 +1,2 @@ +from .frequency_noise import add_frequency_noise +from .degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/__pycache__/__init__.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d83fa49875d0c0cf61d346c86045b8c43354d0c Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/__pycache__/frequency_noise.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/__pycache__/frequency_noise.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b00f5883615826142c8f149ba2d9323768f4450 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/__pycache__/frequency_noise.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__pycache__/__init__.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ebc972e3f596c0d928b68a8b0b77dc0e97ed7a3 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__pycache__/k_degradation.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__pycache__/k_degradation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8980021d850174aa028ece013013329d6b55f406 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__pycache__/k_degradation.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__pycache__/mask_utils.cpython-310.pyc 
b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__pycache__/mask_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..162f1e1826b6b5903e8e85f1166874c1dae3b768 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/__pycache__/mask_utils.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/extract_example_mask.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/extract_example_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1955ea8bf7c2e7d678e80063002dfc6572e7b9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/extract_example_mask.py @@ -0,0 +1,71 @@ +import matplotlib.pyplot as plt +import torch +import numpy as np +from torch.fft import fft2, ifft2, fftshift, ifftshift + +# brats 4X +example = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/BraTS20_Training_036_90_t2_4X_undermri.png" +gt = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/BraTS20_Training_036_90_t2.png" +save_file = "./example_mask/brats_4X_mask.npy" + + +example = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/BraTS20_Training_036_90_t2_8X_undermri.png" +gt = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/BraTS20_Training_036_90_t2.png" +save_file = "./example_mask/brats_8X_mask.npy" + + + +example_img = plt.imread(example) # cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) +gt = plt.imread(gt) # cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) + +print("example_img shape: ", example_img.shape) +plt.imshow(example_img, cmap='gray') +plt.title("Example Frequency Image") +plt.show() + +example_img = torch.from_numpy(example_img).float() +fre = fftshift(fft2(example_img)) # ) +amp = torch.log(torch.abs(fre)) +plt.imshow(amp.squeeze(0).squeeze(0).numpy()) +plt.show() +angle = torch.angle(fre) +plt.imshow(angle.squeeze(0).squeeze(0).numpy()) +plt.show() + + +gt_fre = fftshift(fft2(torch.from_numpy(gt).float())) # ) +gt_amp = torch.log(torch.abs(gt_fre)) +plt.imshow(gt_amp.squeeze(0).squeeze(0).numpy()) +plt.show() +gt_angle = torch.angle(gt_fre) +plt.imshow(gt_angle.squeeze(0).squeeze(0).numpy()) +plt.show() + + +amp_mask = gt_amp.squeeze(0).squeeze(0).numpy() - amp.squeeze(0).squeeze(0).numpy() +amp_mask = np.mean(amp_mask, axis=0, keepdims=True) + +print("amp_mask shape: ", amp_mask) +thres = np.mean(amp_mask) +amp_mask[amp_mask < thres] = 1 +amp_mask[amp_mask >= thres] = 0 + + +#duplicate +amp_mask = np.repeat(amp_mask, 240, axis=0) + +plt.imshow(gt_amp.squeeze(0).squeeze(0).numpy() - amp.squeeze(0).squeeze(0).numpy()) +plt.show() +plt.imshow(gt_angle.squeeze(0).squeeze(0).numpy() - angle.squeeze(0).squeeze(0).numpy()) +plt.show() +plt.imshow(amp_mask) +plt.show() + +np.save(save_file, amp_mask) +# + + +load_backmask = np.load(save_file) +plt.imshow(load_backmask) +plt.show() + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/k_degradation.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/k_degradation.py new file mode 100644 index 0000000000000000000000000000000000000000..1e00e6ac47f15879102935578fd1f2b15185af02 --- /dev/null +++ 
b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/k_degradation.py @@ -0,0 +1,439 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift, fftn, ifftn + +try: + from frequency_diffusion.degradation.mask_utils import RandomMaskFunc, EquispacedMaskFunc +except: + from mask_utils import RandomMaskFunc, EquispacedMaskFunc + + +from torch import nn +import matplotlib.pyplot as plt + +def get_fade_kernel(dims, std): + fade_kernel = tgm.image.get_gaussian_kernel2d(dims, std) + fade_kernel = fade_kernel / torch.max(fade_kernel) + fade_kernel = torch.ones_like(fade_kernel) - fade_kernel + # if device_of_kernel == 'cuda': + # fade_kernel = fade_kernel.cuda() + fade_kernel = fade_kernel[1:, 1:] + return fade_kernel + + + +def get_fade_kernels(fade_routine, num_timesteps, image_size, kernel_std,initial_mask): + kernels = [] + for i in range(num_timesteps): + if fade_routine == 'Incremental': + kernels.append(get_fade_kernel((image_size + 1, image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + elif fade_routine == 'Constant': + kernels.append(get_fade_kernel( + (image_size + 1, image_size + 1), + (kernel_std, kernel_std))) + + elif fade_routine == 'Random_Incremental': + kernels.append(get_fade_kernel((2 * image_size + 1, 2 * image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + return torch.stack(kernels) + + +# --------------------------- +# Kspace kernels +# --------------------------- +# cartesian_regular +def get_mask_func(mask_method, af, cf): + if mask_method == 'cartesian_regular': + return EquispacedMaskFractionFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == 'cartesian_random': + return RandomMaskFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == "random": + return RandomMaskFunc([cf], [af]) + + elif mask_method == "randompatch": + return RandomPatchFunc([cf], [af]) + + elif mask_method == "equispaced": + return EquispacedMaskFunc([cf], [af]) + + else: + raise NotImplementedError + + +use_fix_center_ratio = False + +class Noisy_Patch(nn.Module): + def __init__(self): + super(Noisy_Patch, self).__init__() + self.af_list = [] + self.cf_list = [] + self.fe_list = [] + self.pe_list = [] + self.seed = 0 + + def append_list(self, at, cf, fe, pe): + self.af_list.append(at) + self.cf_list.append(cf) + self.fe_list.append(fe) + self.pe_list.append(pe) + + def get_noisy_patches(self, t): + af = self.af_list[t] + cf = self.cf_list[t] + fe = self.fe_list[t] + pe = self.pe_list[t] + + patch_mask = get_mask_func("randompatch", af, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=self.seed) # mask (numpy): (fe, pe) + return mask_ + + def forward(self, mask, ts): + # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + # print("use_patch_kernel forward:", t) + # print("mask = ", mask.shape) + # masks_ = [] + for id, t in enumerate(ts): + mask_ = self.get_noisy_patches(t)[0] + # print("mask_ = ", mask_.shape) + # print("mask[id, t] =", mask[t].shape) + + mask[t] = mask_.to(mask[t].device) * mask[t] + self.seed += ts[0].item() + + # masks_ = torch.stack(masks_).cuda() + # print("masks_ = ", masks_.shape) + # print("mask = ", mask.shape) # B, T, H, W + + return mask + +get_noisy_patches = Noisy_Patch() + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False): + mask_func = 
get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random', 'equispaced']: + # print("pe") + mask = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'equispaced': + mask = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (pe, pe) --> (1, pe, pe) + + else: + raise NotImplementedError + + # print("return mask = ", mask.shape) + return mask + + + +def get_ksu_kernel(timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=4, + center_fraction=0.08, + accelerate_mask=None): + + if accelerated_factor == 4: + mask_method, center_fraction = "cartesian_random", center_fraction #0.08 # 0.15 + + elif accelerated_factor == 8: + mask_method, center_fraction = "equispaced", center_fraction # 0.04 + + elif accelerated_factor == 12: + mask_method, center_fraction = "equispaced", center_fraction + + + center_ratio_factor = center_fraction * accelerated_factor + + masks = [] + noisy_masks = [] + ksu_mask_pe = ksu_mask_fe = image_size # , ksu_mask_pe=320, ksu_mask_fe=320 + # ksu_mask_fe + if ksu_routine == 'LinearSamplingRate': + # Generate the sampling rate list with torch.linspace, reversed, and skip the first element + sr_list = torch.linspace(start=1/accelerated_factor, end=1, steps=timesteps + 1).flip(0) + sr_list = [sr.item() for sr in sr_list] + # Start from 0.01 + for sr in sr_list: + sr = sr.item() + af = 1 / sr # * accelerated_factor # acceleration factor + cf = center_fraction if use_fix_center_ratio else sr_list[0] * center_ratio_factor + + masks.append(get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe)) + + elif ksu_routine == 'LogSamplingRate': + + # Generate the sampling rate list with torch.logspace, reversed, and skip the first element + sr_list = torch.logspace(start=-torch.log10(torch.tensor(accelerated_factor)), + end=0, steps=timesteps + 1).flip(0) + + sr_list = [sr.item() for sr in sr_list] + af = 1 / sr_list[-1] + cf = center_fraction if use_fix_center_ratio else sr_list[-1] * center_ratio_factor + + + if isinstance(accelerate_mask, type(None)): + cache_mask = get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe) + # print("cache_mask = ", cache_mask.shape) # torch.Size([1, 320, 320]) + else: + cache_mask = accelerate_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + masks.append(cache_mask) + + sr_list = sr_list[:-1][::-1] #.flip(0) # Flip? 
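+        # Walk the remaining sampling rates from most- to least-undersampled; at each step the
+        # previous mask is grown by switching on the not-yet-sampled k-space columns closest to
+        # the center until the column budget for the new acceleration factor is reached.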
+ + for sr in sr_list: + af = 1 / sr + cf = center_fraction if use_fix_center_ratio else sr * center_ratio_factor + # print("af = ", af, cf) + + H, W = cache_mask.shape[1], cache_mask.shape[2] + new_mask = cache_mask.clone() + + # Add additional lines to the mask based on new acceleration factor + total_lines = H + sampled_lines = int(total_lines / af) + existing_lines = new_mask.squeeze(0).sum(dim=0).nonzero(as_tuple=True)[0].tolist() + + remaining_lines = [i for i in range(total_lines) if i not in existing_lines] + + if sampled_lines > len(existing_lines): + center = W // 2 + additional_lines = sampled_lines - len(existing_lines) # sample number + + sorted_indices = sorted(remaining_lines, key=lambda x: abs(x - center)) + + # Take the closest `additional_lines` indices + sampled_indices = sorted_indices[:additional_lines] + + # Remove sampled indices from remaining_lines + for idx in sampled_indices: + remaining_lines.remove(idx) + + # Update new_mask for each sampled index + for idx in sampled_indices: + new_mask[:, :, idx] = 1.0 + + + + cache_mask = new_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + + masks.append(cache_mask) + + # reverse + masks = masks[::-1] + noisy_masks = masks # noisy_masks[::-1] + + + elif mask_method == 'gaussian_2d': + raise NotImplementedError("Gaussian 2D mask type is not implemented.") + + else: + raise NotImplementedError(f'Unknown k-space undersampling routine {ksu_routine}') + + # Return masks, excluding the first one + return masks[1:] + + + +class high_fre_mask: + def __init__(self): + self.mask_cache = {} + + def __call__(self, H, W): + if (H, W) in self.mask_cache: + return self.mask_cache[(H, W)] + center_x, center_y = H // 2, W // 2 + radius = H//8 # 影响的频率范围半径 + + high_freq_mask = torch.ones(H, W) + for i in range(H): + for j in range(W): + if (i - center_x) ** 2 + (j - center_y) ** 2 <= radius ** 2: + high_freq_mask[i, j] = 0.0 + self.mask_cache[(H, W)] = high_freq_mask + return high_freq_mask + + +high_fre_mask_cls = high_fre_mask() + + + +def apply_ksu_kernel(x_start, mask): + fft, mask = apply_tofre(x_start, mask) + fft = fft * mask + x_ksu = apply_to_spatial(fft) + + return x_ksu + +# from dataloaders.math import ifft2c, fft2c, complex_abs + +def apply_tofre(x_start, mask): + # B, C, H, W = x_start.shape + kspace = fftshift(fft2(x_start, norm=None, dim=(-2, -1)), dim=(-2, -1)) # Default: all dimensions + mask = mask.to(kspace.device) + return kspace, mask + +def apply_to_spatial(fft): + x_ksu = ifft2(ifftshift(fft, dim=(-2, -1)), norm=None, dim=(-2, -1)) # ortho + # After ifftn, the output is already in the spatial domain + x_ksu = x_ksu.real #torch.abs(x_ksu) # + return x_ksu + + + +if __name__ == "__main__": + # First STEP + import matplotlib.pyplot as plt + import numpy as np + + image_size = 64 + + masks = get_ksu_kernel(25, image_size, + "LinearSamplingRate", is_training=True) # LogSamplingRate + + + batch_size = 1 + + img = plt.imread("/Users/haochen/Documents/GitHub/DiffDecomp/Cold-Diffusion/defading-diffusion-pytorch/diffusion_pytorch/assets/img.png") + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + # to gray scale + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + # img = np.transpose(img, (2, 0, 1)) + # img = img[0] + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + print(" img shape: ", img.shape, img.max(), img.min()) + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, 
(batch_size,)).long() + print("rand_x shape:", rand_x.shape, rand_x) + + img = img * 2 - 1 # + + masked_img = [] + + for m in masks: + m = m.unsqueeze(0) + img = apply_ksu_kernel(img, m, pixel_range='-1_1', ) + masked_img.append(img) + + masks = np.concatenate(masks, axis=-1)[0] + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + print(" masked_img shape: ", masked_img.shape) + print(" mask shape: ", masks.shape) + + img = np.concatenate([masks, masked_img], axis=0) + + plt.figure(figsize=(100, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + + print("\n\nSecond stage...") + + + # Second STEP + import matplotlib.pyplot as plt + import numpy as np + + image_size = 64 + batch_size = 1 + t = 25 + kspace_kernels = get_ksu_kernel(t, image_size, ksu_routine="LogSamplingRate", is_training=True) # 2 * + kspace_kernels = torch.stack(kspace_kernels).squeeze(1) + + img = plt.imread( + "/Users/haochen/Documents/GitHub/DiffDecomp/Cold-Diffusion/generation-diffusion-pytorch/diffusion_pytorch/assets/img.png") + img = cv2.resize(img, (image_size, image_size)) + + img = np.transpose(img, (2, 0, 1)) + img = img[0] + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + print(" img shape: ", img.shape, img.max(), img.min()) + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + print("rand_x shape:", rand_x.shape, rand_x) + + for i in range(batch_size): + print("kspace_kernels[j] shape = ", kspace_kernels[i].shape, rand_x[i]) + # k = kspace_kernels[j].clone() + rand_kernels.append(torch.stack( + [kspace_kernels[j][: image_size, # rand_x[i]:rand_x[i] + + : image_size] for j in + range(len(kspace_kernels))])) + + rand_kernels = torch.stack(rand_kernels) + + # rand_kernels shape: torch.Size([24, 5, 128, 128]) + print("=== rand_kernels: ", rand_kernels.shape, kspace_kernels[0].shape) + + masked_img = [] + masks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + masks.append(k) + # print("-- k shape: ", k.shape) + # print("-- img shape: ", img.shape) + + img = apply_ksu_kernel(img, k, pixel_range='0_1') + masked_img.append(img) + + masks = np.concatenate(masks, axis=-1)[0] + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + # print(" masked_img shape: ", masked_img.shape) + # print(" mask shape: ", masks.shape) + + img = np.concatenate([masks, masked_img], axis=0) + + plt.figure(figsize=(100, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/kspace_test.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/kspace_test.py new file mode 100644 index 0000000000000000000000000000000000000000..aa490fe2ac1a25366b0750bd5fe3d4b785c414ff --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/kspace_test.py @@ -0,0 +1,274 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift +import matplotlib.pyplot as plt +from mask_utils import RandomMaskFunc, EquispacedMaskFunc + + +try: + from 
mask_utils import RandomMaskFunc, EquispacedMaskFunc +except: + from mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquispacedMaskFunc, RandomPatchFunc + +try: + from .k_degradation import get_fade_kernel, get_fade_kernels, get_mask_func, high_fre_mask, get_ksu_kernel +except: + from k_degradation import get_fade_kernel, get_fade_kernels, get_mask_func, high_fre_mask, get_ksu_kernel + + +use_fix_center_ratio = False + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False): + mask_func = get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random']: + + mask, num_low_freq = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extend the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + if is_training: # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + af_new = 1.0 + (af - 1.0) / 2 + # af_new = max(af_new, 1.0) + + patch_mask = get_mask_func("randompatch", af_new, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=seed) # mask (numpy): (fe, pe) + + mask = mask_ * mask + + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (pe, pe) --> (1, pe, pe) + + else: + raise NotImplementedError + + # print("return mask = ", mask.shape) + return mask + + +# ksu_masks = get_ksu_kernels() +# (C, H, W) --> (B, C, H, W) + + +high_fre_mask_cls = high_fre_mask() + + +def apply_ksu_kernel(x_start, mask, params_dict=None, pixel_range='mean_std', + use_fre_noise=False, return_mask=False): + fft, mask = apply_tofre(x_start, mask, params_dict, pixel_range) + + # Use the high frequency mask to add noise + if use_fre_noise: + fft = fft * mask + fft_magnitude = torch.abs(fft) # magnitude + fft_phase = torch.angle(fft) # phase + + _, _, H, W = fft.shape + + high_freq_mask = high_fre_mask_cls(H, W).to(fft.device) + high_freq_mask = high_freq_mask.unsqueeze(0).unsqueeze(0).repeat(fft.shape[0], 1, 1, 1) + + # Background Noise + sigma = 0.2 + noise = torch.randn_like(fft_magnitude) * sigma + mean_mag = fft_magnitude.sum() / (mask.sum() + 1) + + noise_magnitude_high = noise * (mean_mag) * (1 - mask) # high_freq_mask + + sigma = 0.1 + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude_low = noise * fft_magnitude * mask # (1 - high_freq_mask) + + # fft_noisy_magnitude = fft_magnitude * mask + noise_magnitude * high_freq_mask * (1 - mask) + fft_noisy_magnitude = fft_magnitude * mask + fft_noisy_magnitude += noise_magnitude_high + noise_magnitude_low + fft_noisy_magnitude = torch.clamp(fft_noisy_magnitude, min=0.0) + + fft = fft_noisy_magnitude * torch.exp(1j * fft_phase) + + else: + fft = fft * mask + fft_magnitude = torch.abs(fft) # keep defined so the return_mask path also works when use_fre_noise is False + + x_ksu = apply_to_spatial(fft, params_dict, pixel_range) + if return_mask: + return x_ksu, fft, fft_magnitude + + return x_ksu + + +def apply_tofre(x_start, mask, params_dict=None, pixel_range='mean_std'): + fft = fftshift(fft2(x_start)) + mask = mask.to(fft.device) + return fft, mask # , _min, _max + + +def apply_to_spatial(fft, params_dict=None, pixel_range='mean_std'): + x_ksu = 
ifft2(ifftshift(fft)) + x_ksu = torch.abs(x_ksu) + + return x_ksu + + +if __name__ == "__main__": + # First STEP + import SimpleITK as sitk + + import numpy as np + import os + + image_size = 240 + batch_size = 1 + t = 5 + + + + + use_linux = True + + # Load MRI back here + if use_linux: + root = "/gamedrive/Datasets/medical/Brain/brats/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData" + p_id = 639 + modality = "T1C" + filename = f"{root}/BraTS-GLI-{p_id:05d}-000/BraTS-GLI-{p_id:05d}-000-{modality.lower()}.nii.gz" + img_obj = sitk.ReadImage(filename) + img_array = sitk.GetArrayFromImage(img_obj) + + slice = img_array.shape[0] // 2 + img = img_array[slice, ...] + plt.imshow(img, cmap="gray") + plt.show() + img = (img - img.min()) / (img.max() - img.min()) + + plt.imsave("visualization/original.png", img, cmap="gray") + + else: + # Or use PNG + img = plt.imread( + "/Users/haochen/Documents/GitHub/DiffDecomp/Cold-Diffusion/generation-diffusion-pytorch/diffusion_pytorch/assets/img.png") + img = np.transpose(img, (2, 0, 1))[0] + + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + print("img shape=", img.shape) + + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + + ksu_routine = "LogSamplingRate" # "LinearSamplingRate" # + kspace_kernels, patch_drop_masks = get_ksu_kernel(t, image_size, + ksu_routine=ksu_routine, is_training=True, + example_frequency_img=example) + kspace_kernels = torch.stack(kspace_kernels).squeeze(1) + + # all k_space + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + for i in range(batch_size): + # k = kspace_kernels[j].clone() + rand_kernels.append(torch.stack( + [kspace_kernels[j][: image_size, # rand_x[i]:rand_x[i] + + : image_size] for j in + range(len(kspace_kernels))])) + + rand_kernels = torch.stack(rand_kernels) + + # rand_kernels shape: torch.Size([24, 5, 128, 128]) + ori_img = img + + masked_img = [] + masks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + masks.append(k) + + img = apply_ksu_kernel(img, k, pixel_range='0_1') + + masked_img.append(img) + + # Visualize the masks + masks = np.concatenate(masks, axis=-1)[0] + + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # Save individually + + print("masks / masked_img=", masks.max(), masked_img.max()) + # img = np.concatenate([masks, masked_img], axis=0) + + plt.imsave("visualization/sample_masks.png", masks, cmap='gray') + + # masked_img = (masked_img - masked_img.min())/(masked_img) + # masked_img = np.concatenate([masked_img, 1-masked_img], axis=0) + plt.imsave("visualization/sample_images.png", masked_img, cmap='gray') + + w = masked_img.shape[0] + pr_folder = "visualization/progressive" + os.makedirs(pr_folder, exist_ok=True) + + # Progressive + print() + for i in range(t): + plt.imsave(f"{pr_folder}/{i}_masked_img.png", masked_img[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_masks.png", masks[:, i * w: (i + 1) * w], cmap='gray') + + img = ori_img + # use_fre_noise=False, return_mask=False + masked_img = [] + masks = [] + fft = [] + ks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + ks.append(k) + + img, k, fft_original = apply_ksu_kernel(img, k, pixel_range='0_1', use_fre_noise=True, return_mask=True) + + # k -> fft + fft_magnitude = np.abs(k) # 幅度 + # fft_phase = torch.angle(k) # 相位 + + mag = np.log(fft_magnitude[0]) + masks.append(mag) + 
fft.append(np.log(fft_original[0])) + + masked_img.append(img) + + # Visualize the masks + masks = np.concatenate(masks, axis=-1)[0] + ks = np.concatenate(ks, axis=-1)[0] + + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + + fft = np.concatenate(fft, axis=-1)[0] + + plt.imsave("visualization/sample_noisy_mask.png", masks, cmap='gray') + + # masked_img = np.concatenate([masked_img, 1 - masked_img], axis=0) + plt.imsave("visualization/sample_noisy_image.png", masked_img, cmap='gray') + # print("masked_img shape=", masked_img.shape, w) + + # Progressive + for i in range(t): + # print("masked_img[:, t*w: (t+1)*w] = ", masked_img[:, t*w: (t+1)*w].shape, t*w) + + plt.imsave(f"{pr_folder}/{i}_n_masked_img.png", masked_img[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_masks.png", masks[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_fft.png", fft[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_ks.png", ks[:, i * w: (i + 1) * w], cmap='gray') + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/mask_utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6eda8a0397fb628cabc4e1d97f93ae9db37377f3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/degradation/mask_utils.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This creates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. 
+ """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # print("center_fraction = ", center_fraction) + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. 
This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/frequency_noise.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/frequency_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..275fddb66f47bc02036fca5d31ec121b55939baa --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/frequency_diffusion/frequency_noise.py @@ -0,0 +1,39 @@ +import torch + +def add_frequency_noise(fft, snr=10, vacant_snr=15, mask=None): + ### Determine the noise scaling factor from the target SNR + num_pixels = fft.numel() + + fft_magnitude = torch.abs(fft) + fft_phase = torch.angle(fft) + + # fft_magnitude + mag_psr = torch.mean(torch.abs(fft_magnitude) ** 2) + mag_pnr = mag_psr / (10 ** (snr / 10)) # Calculate noise power + noise_mag = torch.randn_like(fft_magnitude) * torch.sqrt(mag_pnr) + + mag_psr_vacant = mag_psr / (10 ** (vacant_snr / 10)) + noise_mag_vacant = torch.randn_like(fft_magnitude) * torch.sqrt(mag_psr_vacant) + + fft_magnitude = fft_magnitude + \ + noise_mag * fft_magnitude * mask + \ + noise_mag_vacant * (1- mask) + fft_magnitude = torch.abs(fft_magnitude) + + # fft_phase + pha_psr = torch.mean(torch.abs(fft_phase) ** 2) + pha_pnr = pha_psr / (10 ** (snr / 10)) # Calculate noise power for phase + noise_pha = torch.randn_like(fft_phase) * torch.sqrt(pha_pnr) + + pha_psr_vacant = pha_psr / (10 ** (vacant_snr / 10)) + noise_pha_vacant = torch.randn_like(fft_phase) * torch.sqrt(pha_psr_vacant) + + fft_phase = 
fft_phase + \ + noise_pha * fft_phase * mask + \ + noise_pha_vacant * (1- mask) + + noise_fft = fft_magnitude * torch.exp(1j * fft_phase) + + return noise_fft + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/000_kmask.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/000_kmask.png new file mode 100644 index 0000000000000000000000000000000000000000..0ea73db1aa2e0ef80c9ab4f3daec7e8e56ca89cf Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/000_kmask.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..0ec08c5415f91c8543aaafd260a1e050e8d8ec73 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33a983fe87cb202b0d5a4533ae698b635556483fae6bf5271d45bfab7ea6efa4 +size 515323 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..495f679a26e2a23a70396a275495947026e589a5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98eafb8cf3770ecefaba3e189ab4f1dc441e87617cd90156b85e0b03391f44e9 +size 583673 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..1f57448057e8b08b9a8658cc994f5fab6c842aeb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdee791635d994c792f336844563f4b05c8fe062c2a784096cdd18e5f59db809 +size 502289 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b4aa28ea041dac69423050aae93c97b74c460806 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8392c2989cdb4d40ab992758a92b4140e489c00763304f2a5698fd350cbef0a +size 589521 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..d0635ea1b3c591d99b8a28ee37b5df0c6702514d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d18d04fbe71c09e27022b0057b0cffc24732b4fc401a9cee31770a75170d56ca +size 490193 diff --git 
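The degradation implemented by apply_ksu_kernel in kspace_test.py above is a masked Fourier round trip: move the image to centred k-space, zero out the unsampled entries with a mask, and take the magnitude of the inverse transform. The snippet below is a minimal, self-contained sketch of that idea only; the helper name toy_ksu_degrade and the random column mask are illustrative stand-ins, not the module's API.

import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift

def toy_ksu_degrade(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # FFT -> centre the spectrum -> apply the sampling mask -> inverse FFT -> magnitude
    kspace = fftshift(fft2(img)) * mask.to(img.device)
    return torch.abs(ifft2(ifftshift(kspace)))

# toy usage: one 64x64 image and a column mask keeping roughly a quarter of k-space
img = torch.rand(1, 1, 64, 64)
mask = (torch.rand(1, 1, 1, 64) < 0.25).float()   # broadcasts across rows
print(toy_ksu_degrade(img, mask).shape)           # torch.Size([1, 1, 64, 64])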
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..a06a663c0a71d7c56edee2492ba2f364e665da42 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/104150-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:82ba1f2d5a21cca18e5147541be4277163b8b3678f9e7a4d54f7db0948027d5e +size 581381 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..fb88b4b650be0f8fb1c545c15a4d22bc91e54831 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d58db8df051212ee8f0c43fec1e22feeeaa086d454516c58514b59faab1ae04 +size 515603 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..5a7b8be8fe4275a3b2d99ef503f566bde36b719d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7e984d95d8a76431cd06b22d4cf3b5ad6f30d48f2ddfcb30b6acdc30f41fe44 +size 583110 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..cbbda99fba347fce708b864e96566f012e24e104 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2cc77e2cb80cbb84588c69bffbd183c47ad4981496fd4f6776adbd7f59b4386 +size 502746 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f24ae4ad2321ab942836389dd4a4b89629da8322 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b840d0d262afc8baa57b5f02935b7a1f7f348773e52a1b01057498221f540bb +size 587293 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..b5b6d3e5f46b8dfc932bbf34eae0f5334700498c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddfcca7f0ab43e7c95c3e61ed99fdf9691a46517b0e3778e6501ebcaf70aeb14 +size 490013 diff --git 
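The RandomMaskFunc added in mask_utils.py above keeps a fully sampled centre band and draws the remaining phase-encode columns with a probability chosen so that the expected number of kept columns is N / acceleration. The NumPy sketch below restates only that column-mask logic; the name make_random_column_mask and the example numbers are illustrative, not part of the file.

import numpy as np

def make_random_column_mask(num_cols, center_fraction, acceleration, rng):
    # dense low-frequency centre plus Bernoulli-sampled outer columns
    num_low_freqs = int(round(num_cols * center_fraction))
    # chosen so that num_low_freqs + prob * (num_cols - num_low_freqs) == num_cols / acceleration
    prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
    mask = rng.uniform(size=num_cols) < prob
    pad = (num_cols - num_low_freqs + 1) // 2
    mask[pad:pad + num_low_freqs] = True
    return mask.astype(np.float32)

rng = np.random.RandomState(0)
m = make_random_column_mask(320, center_fraction=0.08, acceleration=4, rng=rng)
print(int(m.sum()), "of", m.size, "columns kept")   # close to 320 / 4 = 80 in expectation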
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ebe2866ac155337047024ef263807b3753ed06fd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/108316-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f8e45ee920faf45a5588be39daba815f4c72b371c0dfd5195d37186d0d05ca5 +size 578757 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..db842f021105aafdb795cc3b585fcccb26440d0f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b42ba9cd38866269cb95cf6de285f10f7500152ad7c11001149ea6d9137f42a8 +size 515767 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..0303e9ac3088bf2b1a12e2ccfa96a64b51b13119 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14b301f855c5a648c150d64223d5f801f488cf6c8cec53b63037dbe4a3d6e847 +size 585445 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..893f10b5646daa6e9a513ea7ff9d800878504c64 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87da64ad59ad52136b4df473b28a9777685bb6c8ebe344ca37f4863fdb8bc60c +size 502402 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..57d3ce3870583588852d5ba0054814757843a70b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a24751d6549fef5d96403406b96e8c94dc3c732d120332a7285d47233d7dbe4 +size 588903 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..fd5f432b9e95937835c3635f6826d40e6710f0de --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2fd438f28e66c1da75d73f045dc3c635a6770660a151d8797136a83f4da8440c +size 489407 diff --git 
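add_frequency_noise in frequency_noise.py above derives its noise scale from a target SNR in dB: the noise power is the mean signal power divided by 10^(SNR/10), and the Gaussian noise is scaled by the square root of that power. Below is a hedged, minimal version of just that scaling step; snr_scaled_noise is an assumed standalone helper, not the file's interface, and the file additionally weights the noise by the masked magnitude and phase.

import torch

def snr_scaled_noise(signal: torch.Tensor, snr_db: float) -> torch.Tensor:
    # noise power chosen so that 10 * log10(signal_power / noise_power) == snr_db
    signal_power = torch.mean(torch.abs(signal) ** 2)
    noise_power = signal_power / (10 ** (snr_db / 10))
    return torch.randn_like(signal) * torch.sqrt(noise_power)

magnitude = torch.rand(1, 1, 64, 64) * 10
noisy_magnitude = torch.abs(magnitude + snr_scaled_noise(magnitude, snr_db=10.0))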
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..0bf96572475189e11e12c82ef6a768f7e28f4ade --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/112482-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:831fd0907b19742024b03087dc10c73b9f97c019f2dc74ac8d81277d14b75767 +size 580349 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..927a9c4f1212b1ed362e4f87ef74f1973df021c4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:377315b16007c179ca9edb6d785adf34ff14f713e4bdedfab58c366badf858c9 +size 515544 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b4e58a15b583d5b15374485c5a45b1205a31b5c5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98263a0fd537a4e99851438888339e2793e8bdcf75360f84ec82fed9ac314f48 +size 585912 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..4dc8feecd5253b868abf0f912e066f66740ef98a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1e774a44c4de9d2d1b5dae3e5371bbb4ca04e7f6b1cae923812b8c5d53da8b7 +size 501981 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c26ffc79b9bc6eadd29f11dd07251de6f4c4a1ac --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5661a23c585a6ae7a6594cca276a8b0b04a8d34437ce1566737fd16529fcd12 +size 588456 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..e93ff60e8b68ea854df246fb6e11589388ba1364 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:686d4fe71fe12f85fce30f39bf72e06d96150ce82c8dd4a2bfcf2f2542218a8d +size 489912 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..fa8fcf9b7badd32099ea72d566964d25b410a4df --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/116648-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6dfe7ad516e81774d760ef6cf280634a77520628f1c985508f067e8050f7459b +size 579947 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..72a656e7f7a0d348d9f2062d3e4360e8174f1bb1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:143cc65062f5fc76aefca00270d4cf0c1e1c38cf694b627cbb30f3b2ed6ba846 +size 515312 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..4bf7e64b79ae99a071c7fa0933bf94c26f643c1f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66891737680f46200391f9c24a07b57866a9cf85214e082e367d40b208f706bd +size 585906 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..e7e56c78db6889372049c58dac1ae3e0ead758be --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b36b8cf11f6f8c3ebaacc04367ce695764aea55bdf8861770e5856062ae02ef +size 502234 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..56cfb8f579d254b47fa6499bd08d8b4e52205558 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42c2298eedbdf2fbaff3f9f19c578f9fa258476a77803f62ac70eb8a7bbd1b42 +size 589556 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..a804d8ab7602d0f8be32f8d4ed87cae037c19910 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17ae75aca2cb14062f98a0b010349a2da529b366c5da70e22b1a6ea182ec882c +size 489194 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..a14f3977b5d2dbe5d4cb599929c24218a5bdd530 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/120814-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9152a4703f07fdfb88b44d4fe29c7711f7d5fe91bf4ebd33f17212587e81a960 +size 580803 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..7a9e2796b1fe230c5657cc760def81938d86ca00 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be5bff991467b40d2bb2331eaf7f074713acd9e9d294e80df9383c03712a0825 +size 511879 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..1cd16e6587b73ee3b120532ca787f7a1eb58c4ce --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ffe5f640bef30d05a3291966da85329791c605e371984da8c2013fb8b69b0bb +size 566473 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..005e9a5ba2c7b88749af329c436f54b90464bed7 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e99e2d8a515b31dcba71d4172c04ce6d02ef06f76f0723fd8cc136010d34f1a2 +size 504175 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..74b397e0221d31bbcc4b57c573710f632fe0223f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a67e2d4e8334f455ff1c721247e2d6ed5ece47651391bf6854d6a80edb9efc56 +size 578276 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..7c0ab7f8310985221a937428ad43c09eb5783faa --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1da25f45d6175bbb0a07803d934ac198b64c72754791b273871b3d0549912c9b +size 492857 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..e0a2ac5526426b94903e9bbffae592eae6676bad --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/12498-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45d31d7a5cdbf345fcb10b3ac97bf8963c9d0ffbc68c8b042fb578c80b24ec93 +size 555681 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..47bae6387248119ac9dca4f29fe0b30509a1800c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6456f61495922d16c3d924bc61836775f9b873d7a48e740d5986ec0e85f3168b +size 516134 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..de2bb062b1ec72a9a17d4949b1fee8d51b68d6f8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ee33d94b0bedba853ec826f8701cac0385e044a605ee7a52bc4049db02b602b +size 586516 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f2c1ff898c2a006e75e4c5a35827e69c7c78fded --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:397265990d1582cdfbb7ddee2f163a3419a731df7440e4870a27a029905b17cd +size 502469 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..2c629f6e7aee1fdcb833013926252c05167f1213 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b8d7c9518228b3a9bbc4d3fe15bf4fe345f3c581f87ec291c849690e0568557 +size 589030 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..84eb4d1655e6b0d3afac7fbc11f2a09a94b9e9d3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a073062e1a3a4153712be017b2f02d3fa7c3255cb90c1f2c1768d7ca4f9f41ee +size 490354 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..306853acae85bcad054ac58160584b87ac2cc916 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/124980-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be3c3dcde0682bb88b9da30052f217fa30a526c40ae60bd7d706f9c67dd28960 +size 580303 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..64effca1215d377dcec27af9d20ebc8092645b54 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f172a40b50d258658aa2e5fef893aa93cc556db57fba8f685bd58688a2fa0dd +size 515589 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..1fed1d9c8d98a818981bcce70cbacd56a9a88138 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d8fdaa6a947791e631f571f711955289d0bb93996e4a2e9c594bb6a3cf1ee19 +size 586035 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..a7272a8839c12dac94816331800db988474722e2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58d142d29654317bc827912e5f66aaeec521d613f7c28c31d6b68c5f1cad6fd5 +size 502805 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..4567438550e21e0f2681bd7a77aed6526e5efc7d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5ed92236f4075affdcbe6996c8d09a8c7b9b7d965e96a2d2db9e8202f275481 +size 589266 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..e31b98ed61b6b35627a395f03d503b864e7874a3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2279f4854cea7bf76ff8949d3adc990ad9008e3abd3dae9f8e50c42dce3be601 +size 489587 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d36eb2f5a6209b1ed95d19e217b2b37874165b33 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/129146-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3261530f8189051f9dd7da647ff3cc112a3b24dc49503f9401484219d5231764 +size 582328 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f585c361f49250ad623e318711911296fc40e2c5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:376fb677341ef130ee0497147f1cc62284e085795c3f92e8ff202d53a3f0ccbc +size 514456 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..1682751ed2221f6fefb83833a050aeb65b3377da --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28a76b290645676f84439e2d4f9b022f47813f785f46c52da4e45d8283e5b301 +size 587103 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..8aeef1dd5fb1a0180965505b35b152deb10c806f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ccb5fde35d2136b68b97b1a996debac9aca4a3b1501753a6c1ddaa58db6fdf5 +size 500918 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..93aa9f8cc985f3f7c2da64e0ca7778ce0ac39d49 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bd27d8a47e44292665788ed4219975072ab28985bdce57a7a48de44355ec8d3 +size 588839 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..725737447240485a3dd7757a803d0d6c8d4544d0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d43fe8d6276772b710a9553fae1cf76e48e7856c24073e0779efb222ff93064 +size 489028 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..9440a01219b84a68ebcb32b48917abc0cde1b45c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/133312-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40c2758e17a2fbb511bde5fb2ed5c749e5a8990e08920c2bd5e370d5840f47bc +size 581239 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..eef566c0d904ce5becb2465aaf3701e107ad97a3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cba0451e417eb71dbccf2a661fec9b84233ca8f00285dc99dff87e3424353cf4 +size 515662 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..0a5cf5c08e48106ddb92ced2adb4a3e5c031289d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8e0f95e149d2fc1736d78e7a42f3847a393a1cd10b6afe3dbb1e2bf808a4f95 +size 586198 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..85b77ba47b3752625ae6bd3cb9a1b510fb3c3346 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1137e925a5fc48f3e3afab1f91ad5c82a27d80fed5e1d439de8e4e1244d31e9b +size 501629 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..08a6d5cd6bdc0c1b31747ace0021ee419075eace --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21b5eac8dc4fe48910da28e0291c99cd4eb0384fd824daf84a9b86eff1981765 +size 588704 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f29d0dfaa791178021d7ed4b139a1c38442f2733 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:714512d216b8b3179faa6040be639352b1b0d2d311f478f97b67e132109d820c +size 489760 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..1cc228fb007e5c9427baf2b83cb12c14c87e3011 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/137478-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a151dc4495a01528392312f0b254f0bd31d5e0f8bde1a68f222dee019e4a90e +size 581052 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..bccb0b909d1977c7f9cfb546054c6092f91c588e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4972c5e93db2c2f0f35515c31c6d74a020ecc303639746c1d3c2a572247e4dde +size 515341 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ecd78129007b2b923cb21870481727c742aa7346 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d18b65204aae4c6c5f941151d524f0997f696958dccef611db5dc3a394263184 +size 586359 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..67c644e3cf4fe6c61bdfe4647775009eba421147 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f1c8f9805bc3fd6e0b771952a92a161ff00fe3f5c21dc15dc8362bcbdf6d628 +size 501918 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..0f3ee7eae41257f8e9b050d3d13769b89485b26e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0de91192c357967b8dc645ce521d88a2b4000c25f46c2c14587aa3b33410fa9 +size 589238 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..a3701310cb27da63ec03905adf49b2d92eb14872 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3168e48127c86ff7e6b9720bb959c9adb6cabab8503c664222a075e608d00b68 +size 489258 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..47a6f5dfb2640b925b1f71f85221f1ea7ea998f5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/141644-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be154f0b37d16a8ca852d2ea0d6aff1a3750542024b21c5aa60219bca2a5e454 +size 580509 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..bcabae1546b867477f8bd4037b9d370e42e793f5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64a4d814c26bd94b66c9839a2402ef3e1ddebf1d76fd29607bfb42a164845744 +size 516213 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..24049d631fff322af1e8080a69586f1bbf7c7694 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31431efdf7331b70adc61106308e774ead4745fb57397d4f5c30b0fafc1b8213 +size 586868 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..ddd04b58ae7c0dde5da11ea91ba21f1f56ea7d80 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3def4a42fb4f48d06a14a2880023f7232e134526dbc85a3f0bb8e1b15ea492d5 +size 502731 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..e06af66900ea040e716666ccf21e90da3397ee0f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8450c90e98282ae42179e525f0b5c04b5ce79484c11b62685c75248df12a7ab6 +size 589988 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..6f9b18aca8e465275df1db7448d2d4c6c3c6cfff --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:710854b6f075a75fae5262c66ef8cf7bd7bfc8ffe3df4b7a85283762a7805e3b +size 490050 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..780bb5abe4fe232738016ed3a87a5f3507e1df76 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/145810-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27b960a6c555da56bc238469bba84b4902a4fde923dbed8d03cedc7906b3c43e +size 580532 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..bfb7cc5315cc008fb59e61287e5823472b0c1ca9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c6febf5f5c62a238dacc595ddf946e12eab0f5d2abcd2a68cfe0b4024de5c8a +size 515783 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..5292d027dcb06b5ad96c76f4067765a683929f63 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb1e6596ff373d4b0189f9d5413fc99830d5f452202310fa69054665eca22a53 +size 587751 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..53667fc1cef7e5e0d847d8285518785b524972d6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48cd9c2f88ac2f99772bdca39d26d0665962fbf7fc1dd00268cbc497e51c0a58 +size 502473 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..a1139b4afa2131935c2828c40a452e0d4467b722 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7cc3c25e0392709ef4e967477f8bf5aaed192f5951ee25b7bb9b7eb53dffd886 +size 590703 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..bdb6d76d61d9a0e9d39fc6d1e1bc2f0dbfd068ba --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13f1647a9035ebd1905bf4e39e0a3008e1231c92964210fbd787938408e700f6 +size 490016 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..012913407eec29d25e40794e5a5161397fe04236 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/149976-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e88de8d5a93566868a9cd89c6597e77f8cf11ffbff23562b1e9b49eb0c41ead7 +size 579829 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f3b0f7902bd21a6c7fda47409fe0702bc11aebeb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a138bff59f3f4a7e437281f56fe28a816e5e0d5aff7747394f64112122b6c9b +size 515287 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..98b37067529b77bf90b90f14e6ded9e76bfa33fc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:480dc1045047bccc28bba500a88edea60208ca3419b3dcba924215c1f738ee6e +size 587530 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..b479786959eb919f75a7258fd48ddbcd59693928 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad55b4ab87bf61b917065e1cc8c5fb6a923334eec4bcd71220a3b80f860c371b +size 502676 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..1b0a8442f9c1419d8d14160933311849083519a1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70af64647818e0783e65fdd5762c7195c66b10c3669ff06a30cba5c6559abcf8 +size 591686 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..aee120773156ab04373f2d516611536b231c6e07 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d355a5d53e041e53d1b69eecdca29d0d648145e93286f243301477693685561 +size 490031 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..173d8d932d762bc750ac311dbd4f0a433a8ccf37 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/154142-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c600e3409eb40ac916752bf03620f993bf75b0e97fd9f234cb7ec4fbf7265fa5 +size 581795 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..64b2913911949a8c5a42852cef6fd45771687942 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c09e293745c8100af058c946ee9d9735490396a1862aa2e3706f4c77be68d78 +size 515102 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..a88ecee1bf9e342bcaa43685e586661fb5e73641 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:561f43512e266ff10f80eadcfc90fc339a23df9b01d91f39ed8870add3f5e30f +size 585941 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..cfddab8e0d2b7a123483e976e2e5b7a28c47c35c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5bf1feeb3188a2b1780cd6d24947eaf4ce4319d126ecf39503531b94ff5f95c4 +size 502404 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..fbbc74cccef0945afac11b823e9623e300ce2747 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924a07cedbe185ba37c73dd1f032a1b6c5c0062f8b1a7c386a8ef4ba7d6339ca +size 590630 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..dd7de8a32698dc8e86deda962d23462da7a6f87d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:151f75f56dd0ebe8d73bd0a3dd4d9e36e57206db05be2e0b425a2b9d69c78971 +size 489971 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..65c0cbfe65a3e24c955f365dcd3d205e39467dbb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/158308-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b56eb22a5a44ad9a332aa10444c69505ac5bd70c076e460442d2f345fa22b4d9 +size 580115 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..298de21ac07ea575fe58b50e92c8a5ad5e27d251 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e6f0ad7137835f52acee0996f6a1b664b3209afed44b1fb080508cc7aaf2619 +size 515343 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..fbcaddd3ff2bf6082c657394af17f217f28f5d72 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81c8592c3f01d38510aadf5c86eec8053e3a7086679421ea6e23a9899183942b +size 586197 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..88042dfa9e9b89342573eafac5765abdcb5741c1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1d2351e9ee6224ee7c6bdf0405bd5200d95ae57a4daf8e8fed396618f7e2129 +size 502614 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..aa7f62d13dd67bbde42947d8322a9d7834409c81 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96ee4d04ae69dd4b193a820dbd28d5d6a1c30e931a1cdabc501e0f5e903eeb8d +size 590438 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..018c28bf804fe943cbd88fa30f146a699ba7a413 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:771ecb8f845292f49aca44ef03619ac6b94e76f0ed9b7c7451b5dd8f25b8da99 +size 489369 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c00319b3587d4349f993c8b350ff2a946ff28801 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/162474-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51511859890c01dbf7ca57d4d10253a5f673c244e62bf5f0e42c5f8f7b7ae605 +size 581128 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..de48fc50dcd306374bb03f927abcea91ff52fd16 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b91691d84241382d75f29e106e7945807683924b5e1c6553fb15e590dec2a604 +size 519077 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b72440b137a08397f47082855d68dbcf872d13e2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31d114ac4b8b5e67f1e1d2f9d3e7d9ae02664bc30bf607905d0a8810bb9c8226 +size 581930 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..54c44fbd20755017c9331b4c36a8198840cb3650 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:199c3861d316a870c417fd170f73166784dfe20b2a3a81ca923f0390132cd5d1 +size 508150 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..3731546bc551f3cbab123227401903c84f9832eb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e8a8c28f8992c4637daabfbfc03d2c0d5eacb3a71009f5a8839b0066b7f8d81 +size 592268 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..e139aa68125a571290e4a9920383e619d573b6f6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ca5ccacc826b50f462a790bf80db1dc8420d191533727ef58a00688764fbcbe +size 495402 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b2cb6ffa25533b1925a2696b6fc72ff3f5045ce4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/16664-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c97cf8b8945d2f7f97df774c13f9bbdfbb23f4e78dfcc41eac6f353ce7fff486 +size 567064 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..9ee703caa2ee44b2d7a1cd80c6fa1c6bd04c2326 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4ece01dfe52dcba902f3352cc6880d11cfe266b549b5655ca2f45573e0c0133 +size 512543 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..2b8621ddb873e6c56eb3dfd68ab02f8a8669cdd1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:926d47d7295a9db4a25f1c3dce00644428d77036d28d8a6268f3198292ef62ad +size 584007 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..4399813448e0c89514159b7538f485ec5f2abea3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:123c6f977496bd21090545a623cafa4647d8321db9496fe79536033d3a95256e +size 502348 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..5904c17d98479fa70ffdd966c41423cc0374dfd3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05e6d71cb5aa7403bf259f5842ee5a824a7c3c67cfa78782d800f4de25d58ff7 +size 593397 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..3107853d7d64d04a8fa8f602c9b7a18e4fd6d29e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:906543a162bedbd0f43871cf4b97044483b4b2166775306badd7187a236b608f +size 489776 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..63266b38d99bcc2c8124170dbd43adca98f5a8ba --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/20830-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a6273d8b70db34fcd2912c4fd8d1e0b1db85eda5da5ce51e908e6e2ffd623eb +size 573606 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..c6a7c25d4ac0aa2fbbf2cb88e584fcc03c2d381f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52469488b03a9378dfad1b201ecdd3f06d4a2d5e02379fba1d1fce3c87fd12b8 +size 513797 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..371825286adde64519a1a9caed0824539e545a83 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a181f3631cfb6375a38369de379181781623c6b2c13d115cb48936dd01af9ec +size 583928 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2cf5255ed924a18a7c618c2ba033a17fc3122281 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c4e57e9b1b37cf09f1fb548cee09a330996c8b3f428500f60fc676663f87a68 +size 502797 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b7ea6a385d90a6492372744e80cb6cf04d46ac7f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb3e9682a46c85b57c56845f2a0057854bf61d71ec4c755515ce20f91f56e65a +size 597636 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..552537d275517e3e018f172933d6e22219e619a9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec59d89cea293ae0f384d7c2544e3e2d8765b8d79b9a823f941553812da6a257 +size 489970 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..607e80618f9989470c12b76fd0b5efa6bf86db09 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/24996-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51043dcee06a4e4bffe36a54c725300abf7119bcd211ca48f28ccb829603160b +size 582534 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..3c23e256dab3e05d68e6a86a74d1406ed02106c8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e34ea8d4ec60722d83154b8147bcb41b9d67fdf8fbbd81e756970e91c1954ad +size 513984 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f8155c27a7ab60cc62b61b4afc09ae55fa75c76c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1701f6a636ddf9afdae9e073970093a0fe1012c74c924deb791f27856464353e +size 576219 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..b69a412e61641d7f1ac1d5493e8f630f9b8ff31f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5059915a5a7b0a2ee89fbbce6abf934a4c91757300fac42845399758488fc02c +size 501326 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..e60f2f6beb191dd73e80191923b8474530a39f0a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55caf35ba20cd3b31652b9e5696861b175dd337a2acf867cd1f56e0df4d56229 +size 590253 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..e20e23d6d95f03b8213383e410aa85e5f1cda1b8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f350d291f6460fe2fe6fcc1c9aa47085979a41f3ce364436b0131eec9608c73 +size 488085 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..38d924ec4e5cb07ee3e1ebb18954b29e6fa57631 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/29162-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6e7490dde68313bcc9a076b4b5adc1a14dab8bcecef867fc38ba4ee2545c002 +size 577108 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f78ad596c8f2379a651e8038fd7ebee43e170450 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e364a6a164e7b4af704a9c28a63d03a85ab60026e2892608061cfee65365b454 +size 515659 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..14d834573b9c676f747c057b424fad6937466699 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5865146bf2ed3efb7b9a15224ae221c7702839fee5b26c89601ee861cc299de +size 574748 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..ee81b791935e84327c7d6c12d57c67f8d9dcddee --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf0edf4f090c368eace0e57e1577c37183270e59d314173e1cf20bddfc018ee7 +size 503474 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..92f425c698940b7d6ff6ae59476ad521dcaeee35 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1bdd95a25541684dcc61d2ea832f06378c2cb7a5e762a71d9209850041fefc88 +size 584990 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..28d46248c4453edc60c2a55b7c3b595c08c8246e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b53f85713704fed17f89d67cfa3b0c0141fe10d1ab414bd999e86ce71d52068b +size 490042 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..a6c630868f29dfaf46f72c91d8489fa2fd106b06 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/33328-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad73aa0960e93abf2987a5255c1a96ea0ef6e3963989046aa5c0a4123a7d85da +size 573509 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..0b97b45e831af8a4f05d72d52cf99730f20b34b8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4a9edfee9e6c41b3c446de333a3492ba11fcfe7e6cd8b835f98c2813fa33b2b +size 520611 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..4db5e820449b282d0db640fb273d00e95dc10441 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df74537b6090da2389f25dcf285ea9f984452321ac5a42c244c19d0725315fc2 +size 588357 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..7e0501ac38ea8ce1ccf2bde8a060f1bbec3099b9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aefca2c0aed37e0c4086e4340dbeb370e31025bf913866b42c5694378a1925e8 +size 507186 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..8c2c5ac8f6debdf0e5a8e5df6808d0b79fa9851e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:510ed1d5f5e8ba26a3e5471de7e12d460ca1d52f8547075832ea8b97c452e6ea +size 597117 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..21739f3a1a33e6cd902760e90d434aca7aa742f6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3da4feca192c85b60fe5992f52388b0da944a68cddf8743e3424a4422080f0e2 +size 492194 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..181a36991348b9d1032d1c757db45435318e351d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/37494-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62d86d8347b8bd19c41a44503b78671c23ef25ee0ae0d050c761608768bf92fb +size 585387 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..6ff5cfcce15bbed748e3d1d3d477c9725fcc4616 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2dbf75c5b7edbeeba2123ad059b8cc2952e807210aaf99544e8c2e85f9075d2 +size 514237 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..919875a30a06de5867982eea0b8e32716dfe29f1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf89bff22217ec79e7c6cd4461d534bc82a7e790eff64b648a9bd7c66cc09ddb +size 579617 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..8bffbf049a73240794ac2c722bb3ce2b778e857c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66f95fe1a9edcccdcd5880e52a68547d95b63885f5154083edd29360ace48d6e +size 502764 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..8a8b429028688d284a57cbf2a60f24003fd6c23f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:49623bf4e5b5b00a22ad902a8f8b909b54bef2a073228cfb28e91ee804053fd5 +size 583436 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..5e8e726b522ac7066558fc33f7fbcdca2efb0e13 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d56fe7ffc7f6cb10a8ee1ae2be3f2de4a908ce5fee8ff042c13b81cb671c3755 +size 490026 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..1cd0ff107e72febe1211519081d8b5882f177d93 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/4166-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7eeba81525ebf549f363ac7232ef88197f1e735d4d022a53116a74b387165e89 +size 594157 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..c02f0d2979f8205ccb15245e4fd52aedd2b98b58 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9da4cdfc5c7061adee37fae627882871bdb23911523d10e40fc97b490e912179 +size 515284 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f9b67fd07d9067b582416a70c1e6931cb52a4b47 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f36a64c7364ad190aff2190bd2cd0c676e4bd5629fa6c2d96beaec0b1d83949 +size 580869 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..d926ccee88857eb5542508d54380cdeb43fc3348 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b4e5a5b7498dc63148b61c385fccad88fc481e546c5a4ff7da888252b23aa6f +size 503120 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ec7daf6810f617bc27d37643b98ecd5857f590fe --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a90d5bf8deddd2f18d98854199f09b245fda2f6df577bad543d7a1b8a647d83 +size 592305 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..8d3cb208b4769f0aaa86005b95df8f24484cef6a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33efbac776ce43dba6482303cc1e8b44e2b2c9136b18983921890305c23b89b4 +size 490322 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..294de874ea837b361d58cce7f412ee251a881509 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/41660-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e310beb9de2d130daf20c07aba461c2bad1abb6a51272b03d419da479e337d5e +size 580661 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f33b642850aefd87695ce2572eed2576156f2d01 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35618a8c8124a3fc7bdcab1bab57639c449b5a8bcddbff2eb48ebdecae3b10b7 +size 515799 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c04d17556a821bf28689d0c3b631605a4ac3a006 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f3c1bcca463948b1dd8a4bf5a2764397e10002af48202b197b3ea8707d3f9d2 +size 579936 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..81d31843cc3f6c5891b3b33aaf6984a835f0f81e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbd1f686554a37994c4f23d4bc58d0e2b8570b54b1029905df76e0131d1be0e1 +size 504146 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..51e8d837def94dae27b9cf2bade692c6da5cade9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6477385b9669f33345c82132b87877b16ade0a4bc2b7c82878a3ff43e02de32 +size 593961 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..e3f1684a924d30732d04fbc7349ca2d185c672f8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7541c78ea8ef62d48fb52698e2e5299b95e1f4f7af24ccebe5a4961c381643ce +size 491009 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..adb9319d02d0c82fd6d1c31824d5613d7b8752f3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/45826-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff39cfcc71f9abfbeff890b50366c15bd2a6ae32dc3ab81e97fbbc1f5a6e17e3 +size 582583 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2c69d159a8931decc656a3e0d5bdb842c3aebdfa --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecca62905e57b40c967827f040427c17568497b44868d3089913ab68aced0510 +size 510302 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..576d0efa3611d6b27931d5865f99094a454e965a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:734680057ff2265d20d03df3ac8d96fe63893270db8057eaf7b36d19dd9311a7 +size 576344 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..03d42169303b60bea1fb5d11e7237ff490973e07 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7887ed15327626396806065a1736873da7bdc5f35516029759bea7bc240a894 +size 498221 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..eeb324dcd55f09c51c93aef752b4160fad44381d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03c56657b5b4bc84e476a6c6109ba9ccf46f7bc827e5607871f6d460a03896bd +size 586974 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..598385cbf6c0ccba870b61f88d1542f25cc5c1df --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1941c6baf8c7ec552a1f1c44b29285b67399aabe9c7fe9ff6cc249d7db83d85 +size 485786 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..73921ad877a6c35f0daf2585aa6cb3e5a5279219 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/49992-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac95c912271150ebd601e5fda2af8e2cf3a93b877b62318900e8e017d632c763 +size 579737 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..9e753157885a75f9119b4ea9b9f0a4969bf59bd2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33c3f6d9cd2188e7f6e39048e2c1ba4fe2bbc38b51dee8eb1a1dcb060c422708 +size 515821 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..0a41f1879de901f2eda508956bb5ca2861b54090 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47e195ee9e2511f89b1d246bd66980c6c56fd119d209e61cd3a843e33b5e8c41 +size 581193 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..6c694cc355f5aa78f350d1ae018eba899bd9d0ae --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ede9af39610fd99d3b9db183bd2db281871af979d1d666be673fd36c0f8dfd6e +size 503604 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..01aaf657b8a7c8b11933a0277a9dce1d8047ab11 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:015e4ebbe773f5f5a38c22ab5427dabfdef2edad4b91a888af23f7db915ec4a5 +size 596413 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..37f6409abfb7b085d78f22a517e515bc284648f8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bcb74d804c57bb2ba866377f8b095c539a248fb2505b11801f56cca33e8909a8 +size 489979 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..fe86cd22330f1010f5e3d029cca10dae21edf688 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/54158-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1752143c5c5e6aa4afd869fcb021bf7354349c350501248227d1ee876882db6 +size 582507 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..64d5168ace696fde30bf6aa309811ca2af3d5e96 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15c61658eaea880bf2fdcba1032a8b5392ead070773e54fb5ebe6e174b30e2c9 +size 516386 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..487e40233d4cfb54b781c8d016c56558fd2d29e5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c2ffd33c9687ff076cec22d49a88e90487ce2e8e0dad78c5656c5858c6983e7 +size 578865 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..4814095d8534abaebf8bb9f647526dc0a9a425a1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bd04fb7356659f4aa28933faafcee49c90a76a87ed17fa48f641b8c4154c60c +size 503702 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..3f5bd144db4525b681fb42f38e22a2e1279310fb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b7b24e67f5dc74bf33aa340b0dbc92713ae74e34e5e7153c3aac82bae1b5957 +size 592247 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..918a2ff74f299d7683dd7c858acb08c062136f97 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bb736bdc374d5607b42bdc3ee5f67826dfdddf8baeb8591a7e2829ff0bf262f +size 491150 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..8548d768fe32d44f1ccdc36825cb57d1bcc6eb5f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/58324-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c98b5f3cfa4dc79c724603477645269aa3f2475fa778ac3115ca5d0afd544515 +size 587266 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..6f9b030a160d090a4932545dea8f666c836e27b2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:30e0481c6034baff4817061a57e5f6dcbb8513277347dee37a8e7faf115596d6 +size 516269 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..2517f49f710ac1eacb318a356e4f725b99c3de61 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:289d03ad8a1e3fc4a49bc80daf6ae7492d8bfb12ef5cb7253ebaacb8651fe30c +size 584318 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..a5cb49fd3acd497bad73a564a58997b3b7d303f3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2480677d563db19097423ff425f523b882a945910700cc63cafc43a45e53a2f9 +size 504099 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..a5accbc55b93452c75a4679b364b866b2c409df3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:613e17ecdb36f06bf5373ac8fb385ed782f7dbfc8b8be0d46d37ba93286116f5 +size 586895 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..17dbbab5a34c2a15a79f0704ed39c18ab9c61070 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb9bf7dba5af899718226072432c7b59aca8bc2cb2a36e122ee80616647d5e13 +size 490671 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..efc1fb17079eb40f42a23bf61e56f603b7e1c223 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/62490-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:208a5b82aac75320286065baedc874a9cabc5802f1e951d546d9572f617a19c0 +size 578956 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..889754504ad7734e83a7a84ef664cb9dac15c3c8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b19b3d1d28d09b843862b31639fa8bf6724649ccb482fb4040ba2384adde99b0 +size 517983 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..65364801d08ccee3c3fae87f66e2799d836ec1de --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50055360b89bcb6f78e591fbee53d2efc68468d0064fcb898c1cd2935299f96e +size 584044 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..592b2b8dc44c3625fb432b47dd56817d8480e83e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3b8a9761eabab92b231d062488b5e606ba90129edffe755ec674e0738ef369b +size 505291 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..84d08d6d0c971bd9259b74a52b425a39e5778577 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c770eb6a63542d57c4286c14fa26cb5ba105c5f7d6525e7d0318b287103711ef +size 585335 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..96d75b54b3d57e1e1fb5ac38d5ef56dd30e6ad1e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5eb1a36e5188b076f4cb68d8ccd56b5dae0bf7784775bdc18f4d6ce5e47b987e +size 491916 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..abd747afd2f3d9a0843ea0de2206e5b330441cb5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/66656-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a43e9cb329b3cce1c2d7ba10966d2c62d80246b981f6ba162270458679e38a1d +size 575986 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2fb96bd4ec50894fc4e49951e0e4427b9ee9d56d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:790f0e1f705bf0d3c6d2f0a7555f6a2c6ef63fbe587727d9b397b0f184880323 +size 515209 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..700e33f158691ed24f7bc7a6a8587aa61b7ca53e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:299c14101c075826aa2301be12eabd5e075daf560b41f2a960151d9aead64d8e +size 580329 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2005a87826504df3be18b594ad288c52f0ef180c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5cccb3cf39fee275b10c1294d9d6f5f2090c5d641e5a90a283cd2e27691cca2 +size 503097 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..257e9cdbc45c960280e1e26f3eab38e5c1bf2824 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:473715350308f1f0a8d3bce2ab6c3453a25a9769afa8a138580a8ead67374b71 +size 587346 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f03153f45d5506d426f3d739fcd2c40d22a54cd9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4573979de40baca986f59440abda8c5b89800d032b882446b52e3b6e8a164da +size 490536 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..282fd513d6b6d23773c459c93b47c579011a2137 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/70822-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31dd706fd14c72c451401fd2b648130e03cf7545e03771f5adf8cddc23122e4d +size 576675 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..27f3e31af6a42fa9ae4d237a4fd99adc7ab900dc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53f0b0e642ce1cef246d7d473fe33237986bba3cf33eac72a5c4dc1d2e473840 +size 517694 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..854798cc059d772eef276f78cbe17254fd988896 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1caaa363a4b709317de869f165f0cc18bfa31617c9b05cd002762f5f4775a701 +size 586513 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..9fbb1cebc7714e6694699699ff953f0b1f3ed0b5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44e3700383c952dd7dc9f81eb75b2f3c72c338ca27c49fb1f8aa60a092532f96 +size 505309 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..720b00ee2296b01585267075ba04ecde29e86a75 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:081146d00daa575bac22c3ea322cb159c83120898e5657cabffecd2a87c27b91 +size 592185 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..3581795657620dda43ee0d8c67c182d43707e1ea --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:897d26ac1235ff582747d0fec881c80dc30eb05234893738d243f951e0cc5503 +size 492202 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..42bd4b558a04c76fdef7f0b04b275743883bd38a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/74988-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fec01878937732a494c178b57d3f595b667190813faefba6cf13a5567cd132e6 +size 580813 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..0171e458427f3b8249d90e400f016bbde51b2c48 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1cc76eda6ba00f1cd6ecbb7aa6b2e36930b80573a995b2a5f96c605a6f48f18b +size 515109 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..17e11b10d25f246c249f2b87e0fcec611477d9cb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:973607badcca5cc43fb89706db3cff6967bea31438c30688c409e3d939d64913 +size 582428 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..fc9a56636151a90f4460e2e40a3e934b3ba99a87 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd320cce8a110c02d2377c3227f32f4796d74adea66cb55309b53d57f9e3ad87 +size 502164 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..198d781676cde8c21e23189ab238a81ae1df4268 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:92762f9d61b994db8f25cef7262d65b6833e19840188fc0dd2c05abbe72c7b06 +size 587804 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..8c8fd296b8959419f189b4bb9b8588849123af00 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1033c235a3f498bd48335136bc051d4abba996ac016e6e08b7ab7f724afc98b1 +size 489729 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f48778b002121ebf35817db1c3ab6d340244e096 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/79154-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1d472e57c16fec45a0528ba11ce4bbdd1f8ed93abd1ee7a93227a2d0cf1ba2b +size 578308 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..5980edf434f4a5771a13c0480961e886d74e2141 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:523c909ef2f26d0320f9a6f3211db63fd4db7f8db97c90a7a5669414abed8ddd +size 510597 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..26dfbabda5f383009ab01dea89b9c7016910918e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40249101e94591b1569ae695bfcc71acb0d689863b09fb69faaf82631408b4ef +size 563888 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..eed5ee64ac13d27489d710164f1326cf56cb016e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:446edc04e1cd61566f3ff19f3a0077b750c64900f0ad7e1ab07f7e641bff6f9a +size 501432 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..1730e397bae63d52d85272823f2a6e27cd93e41d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d438e4953eed31797dd79badd2156d0ead97f655c7d679a004496e22ca94f119 +size 567496 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..66e63beb73cddca7b1f97d47a0b86bcc23d7b195 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c61c473776c8bb5a75d372fd2766cda459ac1efa60d7f3175e4d8e4d40958ab +size 488590 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..e7226af72abd047c1d23720f99b04f88a74a6722 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/8332-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be85b3fcb7baaa589fd4dabdc4f9558e3facf8072e317c983cc76b41f4382759 +size 565804 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..a64c11baf33f6d776db73114820f423e859dc496 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4815e4e6128fd7b94186c8cdce58e36dc9e8baef5fe0cb43fc5f5bcf4ed85657 +size 516434 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..e08b248a218882f1c0daa2bbc7056dce85859b42 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:701d969e729704900090613a4479f44dde2dd364430499fe599b12c1f853f06a +size 584010 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2611fd81ac6d87ffa6d2975f813025302784b681 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3cf8a3c286c65e5c3ac198196309dfc80c0ae979aeb4217ca7d6c96ca5a80d9d +size 502689 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f2cc6ca4a0d3045fc950b2f717e4afc64b365868 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ef70e383c8448ceba3471323fb38a916c745cabc0ecc9da3d5757e0847186b1 +size 586529 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..429e0f1086143cf904513305b3d95ed8c86a8817 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1652617f0e995b3d38f8fae8c02fa10008344da74e6fdf52cbf1bb0bec4eff9e +size 489727 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b612f93ed1b6802ec045156324dffdfc13d44040 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/83320-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2c5e97333d891909a552c2f1ccad2ff6440b16b07008ce2aa31700787d5c676 +size 580317 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..02f57d7ec12fc3194b34fbeaf214dfb57f37617d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f8c1b5050b75c60dd660f4c5e09c39b0a553af74f78baffd9c36ee7334d2b9d +size 515792 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..55bbb3ef2d987dcc93a5ecde20e2c3836adbe0ad --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f12f9c8bfe4749e551ec2727ab151ebd7f3032a05516c61c53b2dd7d2913137c +size 584333 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..dae629e566894f41146c15a85ac3ca1f58b004a9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87bb8c05f3e7205553b08b27cb90ec331de471a761c4bcd48295e34887f38a96 +size 502740 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d12d05d8d8bf0c0ff81c3f7e60a7b85b8f84840e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:424bc76ba8f1c082b8c7599cf5c3820f9661fe5786ad8be178fa2256c5ee2de4 +size 587432 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..722f3498d324a8ac8c5174434cf751cfe9bcc3ca --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac868a86e08472f66ccd9032e5e64d2fbff5817a0187cfa5731049b8ffa8432b +size 490112 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..fb9698405574316d384fc4b85c62d92c50c64ad4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/87486-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bbadf36b8313c5f0a0fd81491fc66860ddb5643f4c272cd8e0b15eba5c09687 +size 580010 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..1c05b1ba3493fb780c51fcb222b560be3fcd3cb9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bcbe0413c6d81da80c4238e8d6153f6578583499b46b9357055a1c9fcdc0a954 +size 514910 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..6571f922d567c220bb67a346b059dacc874d35fe --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f11e97a9dfad8e35ef4616a32777a83f28f66ab10d28d2bd627b41a9f9a406d +size 581804 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..d240c351ab885e7865dfa0659247478a5ee3c171 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a7ba2bc17008482d0808d6086203be7888005a7feaa2ca5643bf99c0f85c7da +size 501809 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d378abc39df2832a3f2e51ce6778f5fcc0d8f063 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58215db5ae354199e8121ae47eb07d3db055f87631a3a4cc32caa36520b0a905 +size 585962 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..44340dae72579bb3b93d77fe2b219067d6c359c3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0471ba5a5b5effce53d1b3f9ffb0f2409f5b7f6bb24c36177c50f286ffc1fcb8 +size 488628 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..379d9e3181f92f2d7ad38a37eff5ff387fc997eb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/91652-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d4ab31f5dab16cc74cdaedd7e33284309eb47ad14c39658a932b6c50285c55c +size 578249 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..9324f7a33fa5d59b84698bd028abe94d81f9b5d0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ace704305fe020f9f5718100787e3e9da1b4d4fe5dde238bd20591de56712f3 +size 514324 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..6a2ce2dc1a7fb6819cadd2d6f93b29b1d9363ece --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5fbef1043d73d2f50834bc43b1257abca46c9988348a26a7c1dfab5611f6de1 +size 581718 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..73b9e6424ecb962613b730dd925c69b472c3011e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4856acdd8ec5ab01c62bdd7bbfa906c8bb17e8db6269c7fa1617fa7ace050052 +size 501256 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b8a603230b587d77b7248ea2623d658e47b7b056 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1b42da08fc7b5b604051fb50e987717dba218664574b590050e2d4ff8a5b8f6 +size 586035 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f5c81922011f436672cfac3bb07bb91ed933a445 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c7110a8a151f70dccedfd9bcab2d3eef5c264e278dc4f54a7bd356b8cf2e8d3 +size 488412 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..768ceb0e3023a77ed3c24074d5e83ed8eecb1549 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/95818-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a39c061eb5945e8f1e96bedbabf7602254f698f9b3efff152ed13ca6070f3740 +size 578045 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..1c75d4a204cd05776105a6129d851b817520ebed --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4496d12924323e2ee45c27d0804d597a6d039035af7b07606abb33674375b316 +size 517329 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..befb4f9b17a7d68aeb57f3a4e44ce7b621d796a4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbc09f167ebb3ab432ffa860a7305ae1148d970e084260f671e633b15e850dc8 +size 582460 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2c69630ee05b417fd5f6a7a3ba1085bcae250cbf --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d06ef4d286493287571672ed1fdf27731fd2986fc135d55adec98d1d1adaeb12 +size 502491 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..be7647bf9205d3f4995b48a03fb49c6c66ffb43e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b8cd78c5a46beb7246c86a9f6230adef9dd36f30084b0e57075985055141823 +size 585886 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-2-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..37c48c8f0139c1bfd440a52370c3f94c4e6e1a37 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b694363458690d274d33f8d1c0cafdf6722f4fc3bc9dde8222ec586955e532a4 +size 490728 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..8b936124792f1b108e1d7ddda40bb8df6146d105 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_4X/99984-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00ba6ecf54e07c6bd7c406e7d7495b43402e57035bfadf03dfc8a69636139f8e +size 579651 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/000_kmask.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/000_kmask.png new file mode 100644 index 0000000000000000000000000000000000000000..463072c766cefad62602e5a0afc7ca7bd8cfde08 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/000_kmask.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..cd846d4c27c9961ab52fcb08d172c26f8248346b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a13308843775036bf551b5387fcab02b7b66262964036351ca778288e23664ab +size 463084 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..9e4f63a85127d3cd64097753b4050f633d4a1415 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f5f1bdb3c35d78b0d8e2ccf0397c01f8b8d70e70a9fdf9d9cc190965038fff1 +size 438552 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..15dac21f5d38d5cd40b2093546e1755ef967b054 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f04e368496da65ff9a30d8ff1edbbf5ba2391df19610d397c05d681efea082a +size 452237 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..45f44fca56aea30c92c61016338b859020e70afe --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97ddcccdda1d6f5fd9129a8cfee558df9de6f4b374acff17b9a36169ebc91368 +size 445394 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-2-compare.png new 
file mode 100644 index 0000000000000000000000000000000000000000..16819e2c35839ad6c898d3e76efeef59ecfd4923 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a814608df6d51237a8b77c3cb2773beb59361e3c8ececc7fe1e7741809c38dd +size 445047 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d207f70075401271a5a753e87a54b378e23a3286 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/104150-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0747a8f81ce73fb1440432cb11d2331334b5f3f17a61b38ccc6db0da153597df +size 431912 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..4642dd6f838c624eec56f0da3e9222220e440fc7 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a76408002a270c5af96c6a906ad702112a19314cf36aaf3b93e1ea7c39387eb6 +size 463933 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..8e19fdcf51f6ef059320ee29c21cd85e4b5dac97 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a011c5ec902fc362959135c495c4e5aa56837fb3874288c7ab51737d3b5d35c +size 440664 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..b301c884763e1afdbad50c66c35fa82ccb4b6828 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6d4ed75c149e7e5e05ba55be2471974dd3361736ebe9dbaf6ef5a8aebe81ad9 +size 453362 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..7f80a9308fff09235153abdbbeaf64733573b38a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9f263ebc3896c61984372e614230a96da9686f72e00794b563a45fbdf61a17b +size 443990 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..65624a94490e38df3764a9d58d713595ee6f62ca --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7aca7900043f975f468856c4e6e1d96acc3f05198d7da03d88d87b360867666e +size 445163 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ee178300e9e86171aa403043baf2b14be5532002 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/108316-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f31fb4f8a071673761c9bedcecf7060bc129753227cace771818cb0c6654632 +size 429790 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..bc22e733dba9ba9bebef454796e06aa890e61f1d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0fb848a841cbfeb0ba169ba68100a9612c905c94bfacbb1762f867ba349563f +size 463497 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..788bfb66e94d80ac5849f436e52f34e4111130bd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c06d57c087112fdc6ae135c37666cd73a3eab917c3440ff543303e931da4d52 +size 440325 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f94784263204ae06ceb0aaa2d6792126d1cc9709 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb515f00cee0e6177927dec46a228f42cfa00236d0c38f517039a75f267c12be +size 452733 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..6759b2c1feb7a09d255a575384b5feb07b4e4302 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4899ec052d462188ba28a82fa6aa58bc68129d07b39c3ee4d999c8f08acd8560 +size 445275 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..e7f0b03906dac574b488516e0df321b2e6f08c42 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5690f60fe54c664aeabf14067482bbe28850dd17cb7fc3dc7e84fa85c5bd1fe1 +size 445103 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..e4a01617da1e0eafc56e1cf3663a2e662ea91c40 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/112482-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72a55de3b55a8567f5ba49ba424db7feecc5820998b35fbd276d5c452aafeda1 +size 431864 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..e5f1cb3bb0e128bfba695bba91ce0c071c1769a8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97ff895052c9de83ca7ae5ff97649864b803fe31472a6295279257d82437669d +size 463940 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..40a5d434cbcaf278c657e71853460cbed54d5a39 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73bbaf3634b8cacfcc1a3dc031b575e2d59e6912e061df04482832004a3b3c27 +size 440366 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..787365cf51e7be5690df2b69a90b5bdb731aa02c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8f41e4f8d4be4de0b04d0370ddb9292fe382ad814bea9235c199ce09ffbd95d +size 453105 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..de979cbf3279cf93f03f951c097747e1749d539b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:812d7838261f3f1a9086159713db24fe888d62b457d9076ea83580f14af03c97 +size 448537 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..f3f6c9ed7a33d5ce88fcb72578a923a82c73727f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b736933551ff30b76310e194644245852f914c666eefc282ad80a33834082ec +size 445604 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..435152732fe13abe2ef394b98b267df08991c402 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/116648-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c1bdd2594f5b7acefcdc3034bffb25be8cafaef0f0c716bffbf084a37f1949c +size 431953 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..ff665f78377a799202678c09b9d9a52b8a708357 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c00e45b0187f783ac4bb4ae5cde183bb9715c693ac866644bcf1d727355f55b +size 462992 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..0c4feb9b32cc64b20c907e81123c50b90ca187fc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd3ac7f0e379d84a4d6e8f3fefd24ab230c3940e6ba05c38bccdb5d425e4d90f +size 442996 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..d23df8b1be4fd6a9f6738bc3620c32ed73aec0cd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6c76825f03827fb6d35208cb175b2aa9701ae1e935800f345e7595c651a4dba +size 452662 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..e6ceb961113d23bdf7445ff7499bf66241eb9fec --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37b70aa1047dca36ad2cf9f68258f3342a639c9993f32400d7d2e5426a25e071 +size 446353 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..05343d5cb7dd6ce108a3e2c48c8e92d0c8bd9b00 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74f7a88cb521ca00663350c76bd8185d7acaedf9386d8a731d4ee997044c13a4 +size 445031 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..9efd430f08df1c5a565ff86aaa61b4e7f39a4258 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/120814-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18db88670c05a0e3f44c24090714d7177279292cc54d1628ecacf67dbe9e8de8 +size 433542 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..9b4b78124932d0e10c376f7a131eb219978e1564 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d2161bc1190394fdfdd9c6868ea782f0d0699f9256bd7d61522ed5b87ffe643 +size 461543 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..21c3c92980acadc6991e664a05fa1ca1e9f0a574 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09db2477e2033e4012bccf06bd4c13e02328f7ca7d1a32265d020e564211c936 +size 439736 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..012f3cbece63f9c4f1987943d076e5980f1ed6eb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11eb4a77de408bb3e8fb2adc41bd0d28c59c102e298ba35ebed88775611ad4b5 +size 452169 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..7934963100d9a16e1597b64cddb9bca5d276d6e4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55a020b2f4d28bfa73533735ec115080c9dcc46c723e50c19e3dd11ee477916a +size 443238 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..b5fe789cfadf3cfb15804fd996d42975d0d13582 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4cdb9013aae7c80c75882b4632a679d030c07f342650e5ee7b540b3d4d9acd62 +size 444905 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c31d4a969c5a23a60aea112e05f77bcea19ccdf5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/12498-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf16e686627589a6e80629012e220255721fb1ff39fafdf6dc2ce15dfff2898b +size 450580 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..0eae14dddbeced4aa4214b0355efece3eaecb2eb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c1ec100bf33eb000173bd0facff8f57ee228c72fc59103f8f849822c8eb28ef +size 463893 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..af4ea7843fa20c44582ca876016b2d5f6ed663a5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:781f69c090916cd68fc38e980fed8a45837dbd16c4d0360d5421f77189e15e3b +size 443034 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..0ca10efc08aae30e160b4ab78b6494592e42e436 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d80829ab3b92b900bcf8c607a0c862a4a7f1dc6021d6f0626ed49342d3fb922 +size 452668 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..733fe00536a982012d0c8642b8a808c693ff4d00 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d47164a0470327618e33d64d0feaf9744dc7efb3ad5c647c2f4fef7beb8c314 +size 446325 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..f6bcf307e47f84f8ba59d8f6ac34f9c52bc6f5a1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd93c465ddfec32a84d01ff02309ed50ae7081ac2fe6c191df801877c45a765b +size 444538 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..fb49e684ce2c20989c5a8752d61a3b15ce9feaf9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/124980-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8a9b89aef1a82a766756e870dccd573d67a2f21a8325f746521a45cca018181 +size 431848 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..1a35f2f2a12509ad165b250e74ba326d4912e05b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d633c7a0ee31b7a6b473a70c5baf6e2d5bc464eabd3d46688ad98b55debfc27e +size 463981 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..af3912d600b398fa0fdd702a2307039b82307af5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:848b634674c8e3e3842c8318773c38649b1a69a229a442b0c39459eb43847d22 +size 443288 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..c45603a52eb4b9c4f26c11f223c6b23339ff827c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df543b95ff4c46421d198a3910b5134503b15ae5da4222b7506ce8452ed4f5ef +size 452917 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..2dafc73cbf685d7dde0428539ca802a62993e431 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05492685b51d8cf3d9238b812766338d484494a053dbab42a5a0e500a9fdf83b +size 446805 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..ada3f95e7ac9137dc327b520863d549a24233b60 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b415eeb4909491657b45d061fbbde04b0d2bac8a43a2be250197073ed8c2646 +size 444749 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f0c780e32db414788d74caa62d7e2b28ef6e0677 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/129146-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cbb76a2d1d2b63a6531745968b3f4cb0b63d429d97c03c8cacdd949977d2687 +size 433261 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..8df2ab6e79f3f53137c8a150047c53dfd4111bab --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb5b820febff02a255cc8092b20462ad0ceab81589b6f4a792b01893b118eec3 +size 463851 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..8cb02a6ee547da97af8605635c359b7f7126536f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c184792de9d7199cee2b59bce40c52e4b8af928f0510d7df621a903cd65b839 +size 438725 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..dc01ef50b924852705f10a847d5f3470d3f091dd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5dfbf6110548465d4509e7fbd992c143a2fb30d8ed901204f36f2a4d1c9be877 +size 453071 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..4e805ec1888f770e4513e8817e7b2a4f2f64d4cc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:909635937445b56422dccec0ecc4d08afb5a3372fb54a405702be2c8f7d89744 +size 445708 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..74f651cd08be5e1c8348458f6f3d624474c8ffde --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5380fed007e4860ebb3035e4a5952b10cd2f91e4f4c9bcdbab70044b7e9d0776 +size 444952 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d37d9e8295d98528791d01e4717686baf9b1f66d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/133312-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd1a1455a2b7e95ba1fd377ac476a483ee6b0ff38cc378bb6c7b576f5da1984a +size 430774 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..963d92cc1e26b9491c7c94c49b38db9a9e513832 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4928ea48575c4683e5d1969ab8feb3bd94a12076d796f00f609a57ab94e4426d +size 463362 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..3a8774552dd0ac8f8b02165184f122fc64a08734 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65229c0d4e552d6ebc03b08b1a7a4577c7c26d1e575bd6f7ec098af7a0dc5017 +size 442019 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..d1379822945b11401dbfdc9a5006c384dab714f9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3847775d0a94b3a6ccca33f996b8eab7bc6d76fbf8e257ab03b4cbcd3e02c2b +size 452573 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ff4fa0560a94804733f3207dca3740796bd98051 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d56cdc68558bd130f49bf3d9000bd2ddaab10935a72c7a9582e820ba7014950 +size 445981 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..417a813344b1b9d3cd877f77b784ed35cdd9da38 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c50d11fda8a51ccd756ea8e47e8092716c6bcbe208a47672e724ca365cf5dfad +size 444777 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ae9c24365c58c982ef3ca2b47edbbd2eced4ff79 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/137478-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:509b6c8ad095a01f536857e02008fc2aeb72e4444399cbd5df0e3aa9156d4206 +size 431863 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..bef8c485d36dd64cbff34e8e8f854c3a352a24f8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6005f8acc0312549827c61b9a7ce34ff3a5937addb83833f00762ca707430e33 +size 463479 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..539d369981a41587265e19488ec6ab9f88d17fb0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91cb5c22aec580f543ea0eb2d38acf3ba96c019c36f01157880cf9612791ff5e +size 439820 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..d3fedd84ed7b32f742d60898529e3815c14fdeb0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1ba27ebef14c8008eb9c47f29af1501cb6a893e8d6b5babed4c25b3f17d5c04 +size 452772 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..88c6ca018ed5037d5363369ac942ea7f5a0b0370 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:378ac2f7c9e62adbb99f175e96a7ba244cae1a8225d8bfc6a34197f158f4ae0f +size 447204 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..91515bb47c7a0941379ede30a42bd32b73143a10 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fc775cb88a59ab07a6231dc9a06a7d823f636db0c5fa8a74b6e04c2cfaf7ad9 +size 444830 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..13175c8bbcce33d9a9c8d9a4d72352c27e57b43e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/141644-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3953e6d763f103518a43ee3953b229590a1a6e047495b8f6246563069ab16442 +size 432316 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..7481a3696d17b77f76c16e3b33806cae7f086d13 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e929c17931a4578d712da914446835804036266fcf8381fda89903cae1c126ca +size 463083 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..95e3916780e0fe57b0c5c9eebe274d16772a0252 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4696a7a56724b2f6954427fa5c51cdcb190f203f7a88eae41c2b5be0f687b595 +size 442531 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..ce42aae65f3761c74726e9cad52e72d0726ff95b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f9ee7a39b1c3a8e1a0da18a95c1290b6461b7d8505bfa6e056e47269575852f +size 452339 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c60791cff2cd6ec16670533c923a6345472ac92c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d8f8d222299d40147c9b544606632e11e1b0910b41786b46a2800c1369470fe +size 445317 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..38c30903ec118c28780e5fc0cc29b274a7a3977a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3c0768825cc0c75574cfb31c76b222a3bf037ca3a8d0cdbc891afde8b26e61f +size 444962 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d4e44fd3d31b9daf65659c02e81e2996dd1d230e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/145810-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01f9a4ef3cd247b40ce9a7b434ca0f19d689cbfce00e2435421e14b481256dd9 +size 432625 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..517616e01b70436667e29cd1c2b0eb4ae5964a78 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cd965e402d876fb461fdfec64333152855f4bf4a20559f53484d0ff7871e689 +size 463919 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..0fc317cce1d289592802db88470ee5d8fa1e5ca1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8abca803149d981ed1e22e2db0c2bf8a8f16cf02f2dd9578763a3ead49c910f7 +size 440913 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2dfd80e9e60ed66798540355dcd52cac7d57438a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e08a11cb81c280a3434b8afcf85da15ba443736e5a36b495f850f2b71fe9ec9a +size 453161 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d1ac8008342343771034855954310e69375a3b64 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:891b347fa9fc8c1e0ebbccecb7c55d57b3ee790630baa1f34424b201e75e75f4 +size 446300 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..1d6362a95b7994bf40c419198223e6b466904abc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebb569b12fca9049cb8f9ae54626028a3b7e9bde5f692a4897faaa272b54e710 +size 444735 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c1b49ac8c94191d81b9d19860f5390af617b6b6d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/149976-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf03ebf3bf575ada21a2b143482d032db6aeab225f17fe38b47acde28ca2f758 +size 433155 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..95b2f9c9b2df26a66d1cd8bd1ce5ee236eee29b8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86585da8fac3811c5dea67879f3d74520c0269445adc8aebf1d7a323ff7e318d +size 463213 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c53c1820623b418e4fc97fcb76e1d6797c55ec03 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90afffbd04af67cbea960ee0c8ba1b81e57f8ac582787bbfbae4b602d6a78b39 +size 440206 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..5fe1a7faa184db5ca0ba22ddffe810871fc19dd3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eab4593a506ca2503921ad1a9753acac292bb768577a4a7aa397f49ca8fb4aa8 +size 452587 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..af6197ec1f34db42b4b74faeb7dd4efa318ea986 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a259c9172e5c3e4bbcf63c2a65d9449b1ec1a3e5017b0b58f23602f4d4e22feb +size 445927 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..6a1cb713d4406b068e65c52c086507d87a817953 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ef5bff0930532a2468231199857e9b3bcebf60a25e15a4d3652044d9dce3855 +size 444759 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f02de265d801ad6087874ef40f75128b59ef9da4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/154142-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:624455647fd09dd4f9e8c89f00c3454819dcf996b8fb5b27131f1916e2ad3fda +size 431034 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..e0d9a457c516ede7ce990ae257f0d985b7e668dd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f907368db714e8cac97a653deecab8bf14d0195b441be899d5839d448c39b834 +size 463445 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..190d9059e0524aee62c505e55b9e9717b485f685 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f4ba8bebc6ac6f23b6e7a8dd287f6c4dcf11b8eaab353060f8c95b064da1169 +size 440550 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..3914ccda483e81cc7a69e7061126e5d7ad199dda --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91d3da8ef5281d706e76091e4db1591bfe238c6fdd32a6d75e6b973736f00118 +size 452738 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..994f5278c0b3b38b07bfa766017093ab2fa3d63c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18f09ee8ecb376e1bae4cd0e0d3f05e79ed6ec423f0987fe8e3978878e375a88 +size 446331 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..ff2b1ce5dd873b7ab22a85c8486e67aaf105cb8b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26dfa618d5b44e5a89e8c71eba908210106db84416c663f663324abd9d6c6e9f +size 444708 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ef67c56680cdb3bf2c5017ea8da8048b03fbd455 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/158308-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b073e66b190230f82ed1430138813069533fab717ae5e65a53f076d1b290c02 +size 431801 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..af3227505696f307bc338745776016175e8980c6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c7bbdc9cd6289cb30c20a9e9408a89f153c2dad016caeaa1db55fdeafda0780 +size 463973 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..81658eaed59dd7fe8f7ec0d6d076772ae094cdd6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df8860c6954e1de725cadc679a3518e57bf194a9855aaa1fa34504150fd1e66d +size 440161 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..e23694301f82ead975c29db027da79235595b1ad --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a898ddaf8ab69698b5a29a435fcfbe1b2b7ac3053c8201a66505ef96ccb20280 +size 452625 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..5003b50a5bfa357c93979970227f3330dc3371fa --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:403a0cfec6f5ca17f924b44bcc91fd77595d4eb3ff4ea299b0e1b7097afc668d +size 445346 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..d193305f555157fd89b3179b83312a84bbad44e2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:963977dd43ff2e7c0f5c63d574f3dd866b929a4c74e8793c0316c2a6a56a74e4 +size 444455 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..65bab94614a78257e82dd6c731b500a917849124 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/162474-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9280c2477cfd0dd193c56bea8c608c16840131e9dc10b84106df7637a2c4e73 +size 432373 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..7567046532336467f71e7d24e9cdbcdb3b9eca12 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4c050399f129b2551bcd90d076efabee705f58d57bc9bba14208a180328a91a +size 464743 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..58778ab290bbb1f008037c5df5271b989d6516b4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fd1c09b911b563f3fe6a6a8287bd4ff37f2260808c833086b0e9b9e2cc51c33 +size 458158 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..d985e0f34c56c716216ec75041d3a0eb2a6de3a9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:931790f9e9aab4547f0b952f0d6d967930a402a46ae2ec12aba0cf4acc2c256f +size 453008 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b84c5e19f6e5217ca1087f79e63a61bd6dc6062f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9dd84d7e4394083d639f1a0cf34dfe8d3f0d08f3c8adc726f2e61a320511604b +size 442032 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..e72db06a7e94a2b5c8a976d885c424f53e2dae2d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:217eedf4ea729b75d87ba1fa1706c1e0ddd9fccc1e9f261c839a1f9474cae438 +size 445005 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..905aa90f2ccffa21575640b648c3a72ae6d6df22 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/16664-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91c17dbec80f18bdc00d8900ae5af601b37fe8e21c599e14ea7c2a2b82c9e4fd +size 438530 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2691887682acf70c4deb18c562ed31fdf20bd0f3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9835e7a9b02d2365b06dc985def5c58d0db9c0210ac684324c7ea528cd1b716f +size 463330 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..83d5b1e2d3ddccf025eb6d53655c65b888e62029 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23984dd6e018b31663ef5e63fff6139d12135c5985dd722fc385bce6c85a5a2e +size 440571 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..528c9b223ea745765a4b3cc600c866f47ffd2898 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:840055891e1995c069abe55889e68e5cb2442c443785b8524757e2a5d02360e5 +size 452248 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c6c8d45ee346a4d3b707b5a809195a9c44cffb66 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb7236cd2675aad4a01ebbf27149a26840d7d188ca4747ec1a65634e3a88fde1 +size 446084 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..582526d6986ed7c11e4bc7b6d23eb574ba8a37bb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b2a939cde08e5630d0b9038285af229c7742e0181ee96c13a090fc428b4b7fb +size 444481 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..06cbdedc5b0f54024572108d593ff54f0cff4bd5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/166640-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71b64109e864059facd67c8fd7aefe24d32df9fafc1841b4b2da461f445761f9 +size 432692 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..da6bd925c644767214712a73d76cec35b6c27303 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1475b005e86020da508c98994bcb99c9876a15979d06d397567eb67c02b2a9b1 +size 465940 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..a0a2ca2b9ef5bc13d9457a6e4e218e3ae059b2cc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20237d90b9f677ab470d599326ad88cae61d725c3e53d7ffd2bce0983641bcbd +size 448304 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..fad149d50306c7527fdbd4911318e0405f27e433 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e49c342c5637d914ad656d064a6689015e192b0d9fa6a1b055dcb9cb629b8954 +size 454400 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..7d5f1deeed4a6c3ea5ebbf713c0d13a11929423c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c19888cbb2a60c8d8985a9c3859163261103bcc86755b8ab59d81d42c9fb6cf +size 441904 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..1a0a059d1c3b2acdcebdbd50da40ad254de68f04 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09737f97f1be76c53d7cc2e4d56a375234ad5a10018f1262bf047455919fac5f +size 446549 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..6da4cea2f7afb0d1effef6b7dede3c1889db65b7 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/20830-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e917f59143c2904c38b220bd91eef2c0a8342c846883c079460514b959ad277a +size 438115 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..0111b3ecd8671c3037c21c1c38151f2f45c0cc03 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c96b0203eb7e67b69cb55d5c22000832a5875f622980508fd9eb41825457556 +size 465508 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..83cbc4ab64d779c1cf20c8a63747c8f9d5361fbc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a23d5d9b2bb58d5d699f2bb355c21f1d8f983c982847fd41bce653cb3721986 +size 444634 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..b85b9c55ac3fdc320e0d7d37323b59e84db24584 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7525d04ebbb84f3650c444a9422e4d563168a4ae834ec06c1c6c699d4ca23f9 +size 453914 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d90458fe58f64305cb20b228772b534911e5abc6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be94286b3a013a4b62c97190fe530f4189edd2a9c729b2aa291e0e9b8cccda23 +size 447821 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..4498afc786a5393ecdc795a3ba34bc3764495c48 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70b83004236c9038e78c4caab239780119445250043b4efae0dbabaf12009c3a +size 445686 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f32a7e0cae3847ac8ab49bba0ffaae63ce5f1963 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/24996-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:484a26f0e19cf798e71c6b3d189b12264b2d24b9369d8aa1d9c9be1067e66dd2 +size 439293 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..16fde9afc166e8f5e351eab8f788505504436ce7 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e9ec2d9f2d38c4792068d517e00bac0130c3071f4452a1fdfd70ef98082bbda +size 469005 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..7531fd63d250d04231ca0a99d7f011fa61b8dbc0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40db7c99bdb1427a0449801f7c3ef4b23a32065e22bed0bc15049394d6d586f1 +size 456145 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..0725c09078696710210f0cae7ed1e71882874393 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:213ed4bb30dc6f75eb2c8214794900359c9192de03bff283a1bfff6f7713f00e +size 456715 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..0a3a17bafc2f31d52d49857688d7d1d9eb57e87b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c26ef59af73774050ea23db0189a9234cefc3359127a1a43e3d288e12d042da6 +size 458325 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..67769a5c4c3696d14f9d56025c8466f9bd89f5b8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bca892bbebf068cc090168f9a5efdcde831b6eb40c77f24e23a8ac5cdb2e4575 +size 447874 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..38a921b519097ef3748e5b37a2e18cdadc835ce1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/29162-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9c2f1cb1c000519504dffc3834321447444cafb4f8f1c1cde68c1d48813dca1 +size 450033 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f7c3f19d687dc52de5449980f38b538edac401d6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:103eb2e5c7587051843e841ff98d23450b2da6d2896dbb23c721a557f671d052 +size 465249 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..26bfe3f0770d9e2f9400c4f6bbaff48de79ff0f9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b773d2b421290380d759bf2ac42a4705d1ec5d2b350247ac3fc21632666bd2aa +size 424926 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..182be18c485490bd004b9b7faa253b51f8fa1163 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad2d0d577c1f117af865779b285b27006de25704278f5a04cdd496bb707ce735 +size 453016 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d1a59060aa510ae60662b83d97a0fda2d08695b5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6be7bb514f9abb503f56241bdafebe04e1b87e7ff61a83777ca3b28812a33dac +size 441939 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..3353b7cde0c074b4bec22029cf4403f1fe784990 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7144f7d47fab4a1b8a169e3b69ed23fca503749ad3cb96ddc1264c55374aed92 +size 445121 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..fdf961eed1ff7108646d27215fe6eff60a32d8d1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/33328-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb8fe91e7836e315175487787cb948e5e6383b3f7e71c5883079f02f07a4b4e1 +size 428148 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..234850b1aeccc46ccb667ef7bb89a4f32033571f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:080476e267aeff58b9a411d0112d8e2f94b0af1e9a26e4ca93b96f67194c203b +size 465588 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..8a47b13b59c50fcffce4a8239f948fec9f412cdb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e46516713a8415329eed5fd3ba9be8f464ecfa08760224737c2a70d0dd8170bd +size 444963 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..4db79512be1d1a3ec0b178a0c0c88af0a8bbd5c6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f2fb068c208e8e88a496433e0a29ef3013938e43a4174351bc0567afd580827 +size 454479 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ff2318f1baab3ce0f782b3542f29af5a81c6b3c5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56083dd73640585f102e42ed26beac80d022685d8660f45bd48869465220746f +size 453474 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..75353c197363771472295fc1a6d95f62008eecee --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9e071df7b4ad3caa03309b8725333056665f902ed8f980773be998a8135b099 +size 446373 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..25126bb130938c76ca80f2c22ccddb2655fab652 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/37494-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fb318e61650825554f59cc77f2f10e333518b545494dbb9c9f800c8bb8fcf0d +size 437271 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..6af798b6fc7b426558e768b8cdef22b3a5815a92 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29535241d9fc0641b2f99c2888dc55498cdb5f0da32aa34829619dd4ad7642fa +size 447812 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..77dbc4001c21feb143b3eadc5de272a3da874288 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa3c6aa0ec0c1413acd5b4de520b2771f2f78716ff9db51bac243f861f8a3752 +size 413067 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..b0aa9b7607b352df8057eb6aaa5947149415f138 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c6a1318087a4c0966cbbdcfaa3d0a76ecb701b3f605a3efb02c91d614376eed +size 439124 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..43e44923bda3342ab6756547b1b8fe82533292ca --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b72c656e6ff16ff74462e078fd0045b4ba9a14ca1402f9a248353b548685a6b +size 431475 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..4134e5904e36d3502878ac1a3bbe09a93bae8691 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd99c9c217cc95fb7a0df9fa4549dbd083ae478cca1dc3e51c20c056ba6c5adb +size 432080 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..2273187f174e87dc5e3022bceed16d72c8367a33 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/4166-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6952c625ae7d4fe1131254833d7bcd2b0a28c137310894cae987ef77a0d06907 +size 447003 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..ef69c555a13fd4c1152f28a28be26029795e6e53 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4349858a8d12919a523411cf7e38a620b02b11935ba483abf2f1954fa03ceb0f +size 464682 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..845883c64db4e7b9f66fe90763325a426920d9dc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:099de164b261c4d774c31ca5a696dd97801ce1f79287ff329c87e137775da088 +size 439605 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..7778838230298f49495de5f723aa9069edaa391a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7d49a4a99681ba5070eb4e9bca168d1cbb9bc2aa51579e2ef9eb191dc9b2407 +size 454229 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..420affe304c1785d025923d276deca2c5749a729 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:642cd3851f2eb92c172c8792e27c1259a2aab9d389a7ce9ce42fa922ed161201 +size 454781 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..f4d70d8b7635c97afc5cec0340d53f5bf4a6a57d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:572efae95fdf9745d0cf2bc26e875b03b894874992a0dde2f0d93f66c666187f +size 445985 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..d3970ea5dfc8e00b0b5eab582d258fd2fa7867c7 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/41660-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a2ca5f1239bd27d48fb322fb2a07868f8f8ebc11281ef4205928b94ee64e3f1a +size 436605 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..11cd677e2ceb03e04ea58948a9bef2d3cb59a5b2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa836e75963001aa83a0a76c45da0e6b6158f7eb5f9b8b4b0219d29b321e10b6 +size 465574 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..afe35689c393fdefa2021660f1641de083ee1ca2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fbbe290323f9ba92c574a613bd6f61c7f0c023209e2636269222275343e2ac3 +size 448914 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..de3de63b2c9d950cd1576bc6732f890f2c1bd5e0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7b5c194db5329a171bc159d874292b60a4a82d8103659e3dc98ab32d2adc5f2 +size 455101 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..21cc107c349eb0a4f704554cfbbb52ac33e78238 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cefd459ee78174649d708591b751e5c99f2f845e2e4a1cb8ef4c97f0275c841 +size 453661 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..f1643166af428d4dbe3955a0a0d7dec666283a03 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0e12f90f219f971e6c6998023f7749290a8e1077e48199f7dc48dd4b2d83ccd +size 446055 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..29cb8176eec1baa1b5156226b5838d4290ce092a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/45826-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e39c66fd698389b49448a91889cfdf77b4b9a29157d049c56164e8d6f3198b22 +size 442650 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..64b8376fd84d7cd54fe1f3085ed812d04457ebc9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72749133f8947ace6c307fde0f64588b77988c7bf47b5b773e6dc030d5323fb7 +size 463923 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..9206cc8676b4c9713e2806287345b9a712c6ab47 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd0cb07984c71d3c13e33d51964df14afa243f0251d7fb30e845c73ba0a41001 +size 440171 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..4344452f3460fe5e4833271d20b8394d14a4fa91 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1902090e57121b0485d97c6e0c854117a7a6570f4585c302cb18b1afaed687dd +size 454010 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ecac34e243ed56daa571a50b70124456c19cbd43 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86a71e5b2b38619fbc43b967a87dd4b754bb6fb51a52a65c124cea085ec30a6d +size 448778 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..7cba8eaaf59cc927018acf927348231055b4202a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:806d53bc15b6244bd04b94d97a94eeaeed4874bc1c114cc60213dcec6581a5dc +size 446098 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..be16ecc3df7fb37c9654efd2476af44bd12fc7e8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/49992-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0091b3a6f2560dcdb7c6cd1c82aea8355abf68394496505647eee5b4adbea3b5 +size 440877 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..9e28a45c5d2cd5b40fd6f8e06efd4265f0a24b4c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:132b84f2cf431eb3673d3f8756e586ad3da1d50117d24e6303fd059770099674 +size 462447 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..40e8b460fe37f339ced4fcb5ce948ce24b021d10 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba27bbef1dfbc378c813b7a5c661e851fcd552441b4ec082e7758dd0dbd48c6d +size 434927 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..3e7ea1952af49b9e0aa89e472992fe65611dc4d8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69194f2382279e003028df2a792f02c20deeab622bb0efc2f0d7d5918b276c7d +size 452452 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..20f77955541b68c6a09e260d3fa0a8a3313e17cd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b2f4b1b3842101bc8d1803fc5eefca07967f792802770a09fae095288c47f49 +size 447893 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..06a11f814319c9e2eca1f5767aeb78335dcb3ebf --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62fea36aea391b0abb9847a862dc73cfb3c2dc9acd9e3c73bdeb2d1cfce475b6 +size 445007 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ad3a4ab2261629b66196c3f249938d30301f7a04 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/54158-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:807639e22684edff84d2aae179fb5532479a06e60ea19b34bee6fc7946e3e27a +size 434687 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..2dbcdb228c8078387b6267951ea59bdee75a00b8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6029845c8b9dc3a49be2fc7ff15baa261885a21430a5fa5e179aa0c00e544669 +size 465161 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..bcdc9408e4fc33c56b1d1b63c3a8a0d5e65a5834 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:362b94ab11ccea24ff6e4cb6121f225e8f7e18897cafc17af59c739bc6ec30a0 +size 440077 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..35e44961d66717725b597cf4c850d308a73464b5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17ca231e4fafeba6ca9a51698eb1c36ef23f4c89e67628022392d7dd0c76c2f2 +size 454419 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..79894f15dea566536fb6cab13e5d27c1be109aab --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0715b920fabea02bcae33efb1570e818e5d7884cf7fd677132157968adaae5a +size 451917 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..dbdc40b442ed8300d7bc1efb0cf2e0d9c724a7b7 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b196ea9ef3688d19e624383c7c84a566732a46d9df875e9622fdf516bbf7cdbb +size 446396 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..0a35b2893301ea3dca739104a497b0ecd5d16129 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/58324-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e408d2262feda6324c7fe421eab594e6ad24e1783a9f07a91fbb0970fd4a026 +size 436048 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..4c5bb675d492e58a4a1d70aa52bb27f0263b05fc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffb3e2728149c4dc3c6ea5a1b74f809b6b1a651ea5b73361a2b618390c907822 +size 464039 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..3206c3a94a962978cc640e99b666400dc3cc5047 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10fdd45f401e39234a9915474f0b3c05d8244870b68263d73f4223aba029e484 +size 442058 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..26bba4cfa45197875e4819ff1b8cc99b900af2e2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ad5bfa261bb69d9fd9af3dbaff3b87f4db2977b2ad4d68a12ce6ddff797c028 +size 453318 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..1d3f673df0fba5619530a32587081ddcda66d5db --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7da757c6234d58e78a9be02fbb9909bd101329d314934018bb12f0e92efc3b05 +size 448521 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..38a996297dd1fb3d0a2d80b2b30dc2207bfb92cd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5878846834c8f8d5faacbcc9ebb08b0134420c7c0fff5d48ead629e70e3ed22 +size 445254 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..571cce79585fe6185703b36d158fab1865a72e03 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/62490-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:882292ac5131764dd3ca2f07010b2624e59f37f19d761e4533cd059b12e17dfa +size 437116 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..19bff2953c7175d6043143eda379ea0e6fce61ab --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f1720a371abd7f7fa39e37d9eed8e8182a704de3289683e1460d8dd1c0279f0 +size 464354 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..55edf9fac09f54ea93b1e88b56a099db4446e5eb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:725e85f58dce8facae264930ee75856a86570ab2f91a23789c1fab59f39dddb3 +size 447575 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..f00db68b1470b64e141617b1a6ee9f768ad0d16c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23393128ce913b2f29d691c3ea7867a884412ed4c33101f739ce14a1bf8c50b0 +size 452805 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f47ca3adc5c1f624776cf6f12b24ffc9c94996ec --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:845c24197a0c546f605325a81de389a50bb2ecabb2472e25b9497bdcc937af07 +size 452655 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..07457e1a8c8c441cde6df5546a2438d024e36cde --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5eadcbba9034d630ea1893fc7599992efbafa6b777cc77596a4f02e11652361 +size 445427 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b588a85362043273c4fc51e6d33ab8aa17299879 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/66656-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0875096abd49631897b91c54027f9e281b4c48a08b69325285f6a5b1258bab37 +size 432041 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..37d5228c1ad40c6478dbe99f01b3b5450db132b6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5319bdd8d1fa44652ba4fd7f6b471d95866eee9315d64e965ca80d91f2431d67 +size 463418 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..5a353c2c7a1065faf449d179330e35ea576c0269 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2500dbbe5cd2ad9533cc45303f7dfea0595f28d2d167cc89d02a7310bbdf927f +size 441093 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..db33985ce19e360f3c41bd0b7598876ef4176f51 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a09b8e3177824cc62cfaf5ad765829ca2f3956a8d04740687f7f89d1444fabb4 +size 453601 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..cad3d088a448abb04c27862c84aaf591c0d2486f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5bfaad7d09e6361d294bf501feb6e111f6f3534393af08c2ca57b6007b2d45bc +size 450369 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..f191e06e24e56a4b13573a52aa765bb9f3af5e35 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9af98630190992494d427391140d27131b5e9a812f7b813baacd9b29135eddd5 +size 445623 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..8c32cdf8e23a418d270e460e6645d73eaaf994e9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/70822-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a704a6d680b593cd8bd80697a23155663974af26cce00640c4d62f5aac8f95b7 +size 440779 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..7d7faea8fa0d72b7ac73c199cd09aaabad72c5ee --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:636f9622f757eed7a0e635b06fb023e9db0dc33927bf69d53b100630fd6fe6d6 +size 463417 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..8d84bca5b9ba6f25e87434bf608ca5f55e398ba6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cdd31bc9f9b6de7b831589a21f5c934fafe3860ce7a02a7524bf5fb3ba78a35 +size 441038 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..050217399b32ed8f15a5802e5a3e5c37b87bedbf --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7cc4676dfa7e73728473694f1765b359e3d98988c6ba5c7a4cb5f885f9abbac6 +size 453077 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c3e7eb3b33f5b06ee3e8e7e2a1a4904c7f2cf73e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d2c9553638edd4fd0547fcd922af9d7541c467a5adce1ccff80653ff91899a2 +size 447776 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..3961b05715a62588b47d8f79b5b1e7101b3d17a5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c6eb57a46ed68386bed60bb35c38ca81e8378991da06a7e87c71585bc394812 +size 445893 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..6dd72d5f124058c46cdb99033a582927b6c65392 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/74988-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0641201d0a846d5edc7c5896da49b34be0d9e70f8b14fb9c6adcff641eacabda +size 438993 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..dd9edcc0a1871e701c6e1097164d4580fec622e6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d108ffd493d5b66493379a40cd5c8dd5daef46a2d1703dca6016d796a735504f +size 465266 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..cb81ab6eafb6c0f97d585031f4bcbfe989aace3c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3047321de817ee9c4e7ace313fd8adf46c11e82ac85d2fa54c0cbcc6200caef9 +size 438472 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..3e7067a33e26951e7bac6ed66a5f54913e0a8542 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d4ef4a6dcf97b1510c46c0a7cec84813313ac240b8f54113e87a9792a5ce98a +size 453855 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..50e9c3ef48fc5b7b1bc6ae437cc30f200b363983 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1b821d99f1996366719d20831e784ee268e209f0d55687305e74a0f05e15f0d +size 448628 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..a97fd4a96c60aea645391954433eb7f3c2823a6a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cc01158bd932ebe5cc5a63307e77e28d8740e652e02e050bef8261fe0876b50 +size 446107 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..316e4146b06dd911db6a3e2aba0e0d93f59e1301 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/79154-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62cbb96765b5f10667cf16cb8d2967ec07cf28f8af434e0643ba1c9a52d531cd +size 433219 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..afe211068f929cad76bab111e459f8af1169490d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fdacac8783aabaa875e9ee5593b063dd98231b15a262b613b67c46c824471f8f +size 460980 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..96bcf500351213757c8ccdfbf1c3eb1fc26ad791 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:126342de6a168d116eac51bdf845d397142c72fe9fc4d3f0dff561074538cec1 +size 442576 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..46273984faabf3e327db99e5b6de687a992bfb2e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b828d5a5864e1fcbbcbe45e039ec19033b3f12b926ecdcafb27d62ae9d42817 +size 451212 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f1bc7f0adda979c4192f275994cc8159e89c2440 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3d4a97fe1a0f83fd0fb4f1547c312c1e8366c0ea3c88ca1d4a99e8806948255 +size 437166 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..4f9561b08d4e7f7109c221dace89e48f6f8ce0a1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebbab89022a8d9c0217b5e33268bb83c5f67afd7ef7d3e8840e64254a8619af9 +size 444441 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..3906196c948a57148dfd31f719cba8718a8fe9cc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/8332-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d870fa4ac2818b0c5607675767e3781a22dea7ab03fa4dceeb2199d296e8bc01 +size 447458 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..03772d12a0c057243c60be170ab2a709b133ec79 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e890105bd4d5b4d24dc1778b4b660182405668fdc4e1b0a3ae6799047dbd4a89 +size 464128 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..b023c436e5707ed005f2b127d4fb2554727bdfae --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f3abce8ff88ab5e2667fdbf568f31a10db55e19ddeca8aeb87ccef61b6eabdb +size 443066 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..0b813843b2ffd92da8c1cbd00328b67645d48aa2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aba1dc115bd493f1f75e85ccc14db2910fef8f90bf4edc8da80a7c7d7dd41b0c +size 453164 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..a6924fdb5beac2e2352f26ee33085da2eecb1966 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ebbc9e8317961227b6486f50a5fdf835120e5fcb9dcca374b1c7608e8ac37cd +size 446185 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..44aeea0814c0c361117f1af8440b4aadba386dc1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ff37d338cdb1a1838e143c7a3f2a000b763f0d22e0481638b261e1446ad5e39 +size 445040 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..5e6ca92c8a4a675c1a04adfcaf57f733a9c91d6a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/83320-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:adef989f6b2d27318c9b42240517e2ec9ad979ffc861eb77d616dfd2f94d6852 +size 432051 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..4f14d192c0f40e6d2c9080e61201b3b0f9ebd498 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c330cd68b29715da35ba08edd78a66de9abd1cb89e1ff34c7a6b3bd6262409ea +size 463950 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..68b10ff32f2c46a2bcd4a44903f21b8fc86428a4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a15241495974b8394b779150f40f6bd798351b6838f347c6b1808e5133ef3be +size 440775 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..ecac5488466d875e1a8bfcb847c2dfe58031e319 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bf97505e573bcff850895ae9912ddc046d9727ba83d3dfcb91de83b3774a347 +size 453376 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..86920d5394d53e511e51affbe071acc815c65ee2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00085fa429869909f511b5cfe71ba1a35a5fbd7fa3eefff76b55ad2862b3b8fa +size 447984 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..20ebe3fc508af57815d2ecd94aba8e36c18b4ff9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:645ea760a3a42686e2519baffb4725b232989838b951f2b3c60f3bc0e7bc9a85 +size 445505 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..db8c879f1c809cdb3d6dc92b158faa65322280c2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/87486-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eafb75f2320441831d0eb96919c4b713619127567b39a31ca53bbf5993b91c37 +size 433700 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..10f10e18ba958080eab048bfea07efc932b41049 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ae7620f50e27d890e287b7e49389514a6ab2766579bda6f78eaa97e28cf1710 +size 464505 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..89c9ebb86d50eaad48b91bd62625a822cc0349e3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b13741c9bdc8ff7e0c12565b674cdefe41f2cfe4577beb1a573b2c4012c0c4a5 +size 445045 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..870853dd0312ecbcb38a594cc076953eb510be4a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e3a52e3a67c012a9f9541407a3c17450f8dc3293f1ae779642a34a48a53fc3c9 +size 453018 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..5631f57d460e88409e44cf015d5c7ba35abf25da --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c32f7dee3dd9f0411d9f249111ac80867dee7d2c7d67843b8b368f309521168d +size 444760 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..b31227e8ad20a5b49deaef7726e4c064f18541ff --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd3d275890556a43408fdfc835960bcdbaf2ad5b373f7afd7d83424066e203ec +size 445637 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..f46d102dd54aa49898414a4d77730b052de20bf0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/91652-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15f9a7943645b9f083e1019791f5312b630f9c4bfd9e89bdce0a16e4d4d3d8e1 +size 431660 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..c1efbb2dbb77fa60e0381dd9cc45c445d21ec921 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0943f962999c44ba0a22b11f9cef0dc9f4b05a6a9f77f3ccff52d1febdfd78ef +size 464798 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..077a97f756113f825ff3ba40deedd6878885d520 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4005dd104a9893a85a5cec94b19a5e79ec35558f01de2a11a75932907b92c9ef +size 441079 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..1fadcaaadf3429469f39ec2a395ed1c2c9e3d5fb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba16490954a9bf776bf21ccc04feae7e05bd5ff945e2e000c0e475f416fd3d3f +size 453822 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..c95a06820dbc408ee75709149a6509ef829f4d2a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ed125a26579291ca553890c79623461d105aafcc6886fe2796dd5cc020ab2fe +size 444771 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..4a7848122f054cc6d830d05d2f6409cdd4af98fd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:836295ea766bccf35435d447b3bd529ee84124f47f29e59d31a2c8d015cad2de +size 445118 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..2198c56c36309aa1d5014671ebf958c0b0aea041 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/95818-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:298e427c8e8437de26b2d8170f1bb369e402be38f9d6d733dd2819f94fc4a939 +size 432275 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-0-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-0-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..ef23ee83920dacfaf1bf0461f2bb6f0cc49a3010 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-0-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d3f91643a81199013d972205579373f47f323c4ee8533db09cb967d87537691 +size 463708 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-0-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-0-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..cbd53e877ea95c380a350f559632795b287442ef --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-0-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5973fe98077f23d24056a2a73ce2b7e42d6306c5e93b50512e9b9a239931d204 +size 439484 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-1-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-1-compare.png new file mode 100644 index 0000000000000000000000000000000000000000..683d249959370df73a78b1ca5420e420b8014433 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-1-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1604f769fde73061c8c394412517072742adb7cef4f7ab1ebe969173e473ddcc +size 453689 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-1-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-1-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..ea32a455876066724b7a09b3db0161124fdca1ee --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-1-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb8cc629ea4499da658f11012c7dd8a429a942e7b63c34dc11736ac51d5e89a6 +size 448369 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-2-compare.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-2-compare.png new file mode 100644 index 
0000000000000000000000000000000000000000..8468ff4097972c9b36c7088f4c981d855f04758a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-2-compare.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cafef5f9513a701b3506c8b1872ab3e41f17d3473e72353cc76bd83244e06f8 +size 445973 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-2-middle.png b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-2-middle.png new file mode 100644 index 0000000000000000000000000000000000000000..6af8ac72333a53921fcb98652e76a8ea0942d195 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/image_results/fastmri_8X/99984-2-middle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f1ff662a298bc66d55783680d7311544c087e89c29c21e197af8ce202e567ba +size 432372 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/knee_data_split/singlecoil_train_split_less.csv b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/knee_data_split/singlecoil_train_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..d85707318750900b14a6e7100541242a60b7a310 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/knee_data_split/singlecoil_train_split_less.csv @@ -0,0 +1,227 @@ +file1000685,file1000568,0.301723929779229 +file1002273,file1000481,0.302226224199571 +file1000472,file1000142,0.304272730770318 +file1002186,file1000863,0.304812175768496 +file1002385,file1002518,0.305357274240413 +file1000981,file1000129,0.305533361411383 +file1001320,file1001948,0.306821514316368 +file1000633,file1002243,0.306892354331709 +file1001872,file1001294,0.308345907393103 +file1001474,file1001830,0.310481695157561 +file1001005,file1001283,0.310497722435023 +file1001690,file1001519,0.310709448786299 +file1002469,file1001811,0.31193137253455 +file1000914,file1000242,0.31237190359308 +file1002284,file1002012,0.315366393843169 +file1001721,file1001328,0.31735122361847 +file1000807,file1002334,0.320096908959039 +file1001944,file1002335,0.320272061156991 +file1002090,file1002431,0.320351887633851 +file1000499,file1002063,0.320786426659383 +file1001362,file1000509,0.32175341740359 +file1001421,file1000597,0.324291432700032 +file1000349,file1000321,0.324545110048573 +file1002123,file1001235,0.327142348994532 +file1001867,file1002086,0.328624781732941 +file1001007,file1001027,0.330759860300298 +file1001915,file1000088,0.331499371283099 +file1001661,file1000313,0.331905252950291 +file1000383,file1000307,0.339998107225229 +file1000116,file1000632,0.34069458535013 +file1002303,file1000173,0.343821267871409 +file1000306,file1001277,0.344751178043605 +file1000003,file1001922,0.346138116633394 +file1000109,file1000143,0.347632265547478 +file1001999,file1000115,0.348248659775587 +file1000089,file1000326,0.348964657514049 +file1001205,file1002232,0.349375610862454 +file1000557,file1000619,0.351305005151048 +file1001823,file1000778,0.352076809462453 +file1000806,file1001130,0.352659078122633 +file1000365,file1000351,0.352772816610486 +file1002374,file1001778,0.352974481603711 +file1002516,file1001910,0.359896103026675 +file1001200,file1000931,0.360070003966827 +file1001479,file1000952,0.360424533696936 +file1000850,file1001942,0.362632797518558 +file1001426,file1002143,0.363271909822866 +file1001304,file1001333,0.36404737582222 +file1000390,file1000518,0.364744579516818 +file1000830,file1002096,0.365897427529429 
+file1000794,file1001856,0.365973692948894 +file1001266,file1001327,0.366395851089761 +file1001692,file1002352,0.36655953875445 +file1001564,file1001024,0.367284385415205 +file1001861,file1002050,0.36783497787384 +file1002066,file1002361,0.367964419694875 +file1001613,file1002087,0.368231014746024 +file1001931,file1000220,0.368847112914793 +file1000339,file1000554,0.370123905662701 +file1000754,file1002208,0.37031588493778 +file1001067,file1001956,0.371313060558732 +file1000101,file1001053,0.372141932838775 +file1002520,file1002409,0.372501194473693 +file1001459,file1001615,0.373295536945146 +file1001673,file1000508,0.376416667681519 +file1002201,file1001228,0.376680033570078 +file1000058,file1002449,0.376927627737029 +file1001748,file1001042,0.378067114701689 +file1001941,file1000376,0.37841176147662 +file1000801,file1002545,0.378423759459738 +file1000010,file1000535,0.38111194591455 +file1000882,file1002154,0.382223600234592 +file1001694,file1001297,0.382545161354354 +file1001992,file1002456,0.382664563820782 +file1001666,file1001773,0.382892588770697 +file1001629,file1002514,0.383417073960824 +file1002113,file1000738,0.385439884728523 +file1002221,file1000569,0.385903801966773 +file1002296,file1002117,0.387319754665673 +file1000693,file1001945,0.387855926202209 +file1001410,file1000223,0.391284037867147 +file1002071,file1001425,0.391497653794399 +file1002325,file1001259,0.391913965917762 +file1002430,file1001969,0.392256443856501 +file1002462,file1000708,0.393161981208355 +file1002358,file1001888,0.39427809496515 +file1000485,file1000753,0.395316199436001 +file1002357,file1001973,0.39564210237905 +file1002130,file1002041,0.395978941103639 +file1002569,file1000097,0.397496127623486 +file1002264,file1000148,0.397630184088734 +file1002381,file1001401,0.398105992102355 +file1000289,file1000585,0.399527637723015 +file1002368,file1001723,0.400243022234875 +file1002342,file1001319,0.400431803928825 +file1002170,file1001226,0.400632448147846 +file1001385,file1001758,0.400855988878681 +file1001732,file1002541,0.40091828863264 +file1001102,file1000762,0.400923140595936 +file1001470,file1000181,0.401353492516182 +file1000400,file1000884,0.401562860630016 +file1002293,file1002523,0.401800994807451 +file1000728,file1001654,0.402763341041675 +file1000582,file1001491,0.403451830806034 +file1000586,file1001521,0.403648293267187 +file1002287,file1001770,0.405194821414496 +file1000371,file1000159,0.405999000381268 +file1002356,file1002064,0.406519210876811 +file1000324,file1000590,0.407593694425997 +file1001622,file1001710,0.40759525378577 +file1002037,file1000403,0.407814136488744 +file1002444,file1000743,0.40943197761463 +file1001175,file1002088,0.410423663035312 +file1001391,file1000540,0.410854355646853 +file1002133,file1001186,0.411248429534111 +file1001229,file1001630,0.411355571792039 +file1002283,file1000402,0.411836769927671 +file1000627,file1000161,0.412089060388579 +file1001701,file1001402,0.412854774524637 +file1000795,file1000452,0.413448916432685 +file1000354,file1000947,0.41459642292987 +file1002043,file1002505,0.414863932355455 +file1001285,file1001113,0.418183757940871 +file1000170,file1001832,0.419441549204313 +file1002399,file1001500,0.419905873946513 +file1002439,file1000177,0.42054051043224 +file1001656,file1001217,0.420597020703942 +file1000296,file1000065,0.420845042251081 +file1000626,file1001623,0.42087934790355 +file1001767,file1000760,0.422315537515139 +file1000467,file1001246,0.422371268999111 +file1001033,file1000611,0.42425275873442 +file1002304,file1000221,0.425602179771197 
+file1001737,file1001141,0.425716789218234 +file1001565,file1000559,0.426158561043574 +file1000249,file1000643,0.426541100077021 +file1002014,file1001109,0.426587840438723 +file1002006,file1000790,0.427829459781438 +file1000193,file1000750,0.428103808477214 +file1001993,file1001110,0.428186367615143 +file1002094,file1001814,0.428868578868176 +file1000098,file1001420,0.428968675677784 +file1000336,file1000211,0.430347427208789 +file1001498,file1002568,0.43204475404071 +file1001671,file1001106,0.432215802861284 +file1000426,file1002386,0.43283446816702 +file1001520,file1002481,0.434867670495723 +file1002189,file1001432,0.434924370194975 +file1001390,file1002554,0.435313848731387 +file1002166,file1001982,0.435387512979012 +file1001120,file1001006,0.435594761785839 +file1000149,file1001985,0.436289528591294 +file1001632,file1001008,0.436682374331417 +file1002567,file1001155,0.437221000601772 +file1000434,file1002195,0.438098100114814 +file1002532,file1001048,0.438500899539101 +file1001605,file1000927,0.438686659342641 +file1000479,file1000120,0.439587267995034 +file1002473,file1001388,0.439594997597548 +file1001108,file1002228,0.440528754793898 +file1002099,file1002056,0.440776843467602 +file1000191,file1002127,0.441114509542672 +file1000875,file1002494,0.441378135507993 +file1002161,file1000002,0.441912476744187 +file1002269,file1001220,0.442742296865228 +file1001295,file1001355,0.4435162405589 +file1001659,file1001023,0.444686151316673 +file1001857,file1001378,0.447500830900898 +file1001183,file1001370,0.447782748040587 +file1000428,file1000859,0.448328910257083 +file1000588,file1002227,0.448650488897259 +file1001098,file1000486,0.448862467740607 +file1001288,file1000408,0.450363676957042 +file1002097,file1001210,0.451126832474666 +file1000216,file1001082,0.451550143520946 +file1001746,file1001642,0.451781042569196 +file1002388,file1000204,0.451940333555972 +file1000021,file1000560,0.452234621797968 +file1000489,file1001545,0.452796032302523 +file1001116,file1000883,0.453096911915119 +file1001372,file1000561,0.45532542913335 +file1001276,file1000424,0.45534174289324 +file1000974,file1002098,0.455371894001872 +file1002566,file1002044,0.455937677517583 +file1000262,file1002046,0.456056330767294 +file1001619,file1001342,0.456559091350965 +file1000045,file1001616,0.457599407743834 +file1001468,file1002115,0.458095965024278 +file1001061,file1000233,0.460561351667266 +file1000558,file1000100,0.461094222462111 +file1000605,file1000691,0.461429521647285 +file1000640,file1000384,0.463383466503099 +file1000410,file1001358,0.463452482427773 +file1000851,file1001014,0.463558384057952 +file1001092,file1000138,0.463591264436099 +file1000061,file1002049,0.465778207162619 +file1001206,file1000983,0.466701211830884 +file1000256,file1000475,0.466865377968187 +file1002434,file1001387,0.467154181996099 +file1001036,file1000210,0.470404279499276 +file1001540,file1001860,0.472822271037545 +file1001244,file1001154,0.475076170733515 +file1000131,file1001526,0.475459563440874 +file1000180,file1002045,0.476814451110009 +file1001837,file1000637,0.478851985878026 +file1002425,file1001891,0.481451070031007 +file1001056,file1000682,0.482320170742015 +file1002276,file1000777,0.483452141843029 +file1001139,file1002544,0.487462418948035 +file1000548,file1001257,0.488098081542811 +file1000188,file1001286,0.488423105111001 +file1001879,file1000999,0.488449105381724 +file1001062,file1000231,0.48930683373911 +file1000040,file1001873,0.492070802214623 +file1002286,file1000066,0.493213986773381 
+file1002474,file1002563,0.501584439120211 +file1000967,file1000563,0.502066261411662 +file1001307,file1002048,0.50460435259807 +file1000483,file1001699,0.511819026566198 +file1001528,file1000285,0.512629017841038 +file1001742,file1002371,0.513805213204644 +file1002397,file1000592,0.515406473057 +file1000069,file1000510,0.528220553613126 +file1001087,file1001300,0.536510449049583 +file1001991,file1000836,0.538145797125916 +file1001382,file1001806,0.538539506621535 +file1000111,file1001189,0.557690760784602 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/knee_data_split/singlecoil_val_split_less.csv b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/knee_data_split/singlecoil_val_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..a1cbac5537562063359f4ac3e0985de51cb989b2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/knee_data_split/singlecoil_val_split_less.csv @@ -0,0 +1,45 @@ +file1000323,file1002538,0.30754967523156 +file1001458,file1001566,0.310512744537048 +file1000885,file1001059,0.318226346221521 +file1000464,file1000196,0.321465466968232 +file1000314,file1000178,0.327505552363568 +file1001163,file1001289,0.328954963947692 +file1000033,file1001191,0.330925609207301 +file1000976,file1000990,0.344036229323198 +file1001930,file1001834,0.345994076497818 +file1002546,file1001344,0.351762252794677 +file1000277,file1001429,0.353297786572139 +file1001893,file1001262,0.358064285890878 +file1000926,file1002067,0.360639004205491 +file1001650,file1002002,0.362186928073579 +file1001184,file1001655,0.362592305723707 +file1001497,file1001338,0.365599407221502 +file1001202,file1001365,0.3844323497275 +file1001126,file1002340,0.388929627976346 +file1001339,file1000291,0.391300537691403 +file1002187,file1001862,0.39883786878841 +file1000041,file1000591,0.39896683485823 +file1001064,file1001850,0.399687813966601 +file1001331,file1002214,0.400340820924839 +file1000831,file1000528,0.403582747590964 +file1000769,file1000538,0.405298051020298 +file1000182,file1001968,0.407646172205036 +file1002382,file1001651,0.410749052045234 +file1000660,file1000476,0.415423894745454 +file1002570,file1001726,0.424622351472032 +file1001585,file1000858,0.426738511964108 +file1000190,file1000593,0.428080574167047 +file1001170,file1001090,0.429987089825525 +file1002252,file1001440,0.432038842370013 +file1000697,file1001144,0.432558506761396 +file1001077,file1000000,0.441922503777368 +file1001381,file1001119,0.455418270809002 +file1001759,file1001851,0.460824505737749 +file1000635,file1002389,0.465674267492171 +file1001668,file1001689,0.467330511330772 +file1001221,file1000818,0.469630000354232 +file1001298,file1002145,0.473526387887779 +file1001763,file1001938,0.47398893150184 +file1001444,file1000942,0.48507438696692 +file1000735,file1002007,0.496530240691134 +file1000477,file1000280,0.528508000547834 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/best_checkpoint.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..89a7b3aa18bce7f2161aaac7bbb89072c653423c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4852e8bc8d807c5869b6274e767661b03c719534618b737bd8d6133f9d0838ef +size 56614874 diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_100000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..566857fbd458f90f2e08da1bd06e9053a61efaa5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:550e68895de54e89ad303fd299a8a297e8f4d4e2058d9a1a92c98cae9ca826be +size 56611226 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_120000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_120000.pth new file mode 100644 index 0000000000000000000000000000000000000000..6fc92ef9f5064825375e6f82495580185ab816fe --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_120000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c8fecf134d43e53fcdc52189cbac52ffc84c12406b2f728cfdb602342402697 +size 56611226 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_140000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_140000.pth new file mode 100644 index 0000000000000000000000000000000000000000..a2fbf41c4bdf850ac99444211204ec676b1663c4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_140000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5efbd8fa694200b15850aa0a43ecfbd127de80c1e69095f798f4f329b7ba2e3f +size 56611226 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_160000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_160000.pth new file mode 100644 index 0000000000000000000000000000000000000000..5892fc13a6f4e779347bc31b6661426ce1cacbea --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_160000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c71ce5e817533eb5f97438e422cc1021b1be12db2eee226f57917585c930c19c +size 56611226 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_20000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_20000.pth new file mode 100644 index 0000000000000000000000000000000000000000..b5793062616c44897a1b705879f41604e779bd7d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_20000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a773e0c55ba60882f174cc1f58907f635ce6a76849934caf07a6ec126559be3 +size 56610314 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_40000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_40000.pth new file mode 100644 index 0000000000000000000000000000000000000000..4cec6a4957d540441e10674ab2205a05ddcce439 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_40000.pth @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:26c111ee6b264932eca5bc316ad2888a3f1587acd0bdb3fd101637168a43cb35 +size 56610314 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_60000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_60000.pth new file mode 100644 index 0000000000000000000000000000000000000000..d72edb7b48a23431840fa777933749e946cc7c03 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_60000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b962de9f09fa13987fb49b141104f3e847f2a268c7f7b3baa779c9dc5b4ef5e +size 56610314 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_80000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_80000.pth new file mode 100644 index 0000000000000000000000000000000000000000..f78f7c85f77c95205d75a96c8f800a876b10abed --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/iter_80000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b12db11944abdca8e57cc433dd45ada8da20f5a77fb79d9f0eecc9f88b3f8b72 +size 56610314 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log.txt b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..40d6c26e57a1c44a84922cdb4c02c9d3d60a8301 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log.txt @@ -0,0 +1,129 @@ +[02:43:59.934] Namespace(root_path='/home/v-qichen3/MRI_recon/data/fastmri', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='0', exp='FSMNet_fastmri_4x', max_iterations=200000, batch_size=2, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', grad_accum_steps=1, norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[02:45:36.907] Namespace(root_path='/datasdc16T/qichen/blob/qichen_blob/MRI_recon/data/fastmri', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='0', exp='FSMNet_fastmri_4x', max_iterations=200000, batch_size=2, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', grad_accum_steps=1, norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[02:46:10.001] 
Namespace(root_path='/home/v-qichen3/blob/qichen_blob/MRI_recon/data/fastmri', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='0', exp='FSMNet_fastmri_4x', max_iterations=200000, batch_size=2, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', grad_accum_steps=1, norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[03:30:47.609] Namespace(root_path='/home/v-qichen3/blob/qichen_blob/MRI_recon/data/fastmri', MRIDOWN='4X', low_field_SNR=0, phase='train', gpu='0', exp='FSMNet_fastmri_4x', max_iterations=200000, batch_size=2, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', grad_accum_steps=1, norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='random', CENTER_FRACTIONS=[0.08], ACCELERATIONS=[4], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[06:39:02.661] +Epoch 0 Evaluation: +[06:40:05.278] average MSE: 0.055395592004060745 average PSNR: 27.894657764339996 average SSIM: 0.6802522455462984 +[09:18:48.340] +Epoch 1 Evaluation: +[09:19:47.763] average MSE: 0.04715090990066528 average PSNR: 28.719138381914124 average SSIM: 0.6993489941677059 +[11:28:51.740] +Epoch 2 Evaluation: +[11:29:48.108] average MSE: 0.045454103499650955 average PSNR: 28.92375825704511 average SSIM: 0.7046777659238931 +[13:51:04.617] +Epoch 3 Evaluation: +[13:52:02.312] average MSE: 0.043793123215436935 average PSNR: 29.131928554642442 average SSIM: 0.7127201530074037 +[15:23:18.419] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_20000.pth +[15:46:24.884] +Epoch 4 Evaluation: +[15:47:23.724] average MSE: 0.0428110770881176 average PSNR: 29.24158355752037 average SSIM: 0.7117530350941368 +[18:49:13.094] +Epoch 5 Evaluation: +[18:50:10.183] average MSE: 0.04247137904167175 average PSNR: 29.284277478629047 average SSIM: 0.7120290539084823 +[21:40:08.054] +Epoch 6 Evaluation: +[21:41:07.592] average MSE: 0.042885780334472656 average PSNR: 29.25133739144697 average SSIM: 0.713841478769817 +[23:27:10.807] +Epoch 7 Evaluation: +[23:28:09.362] average MSE: 0.042730510234832764 average PSNR: 29.297680455771864 average SSIM: 0.7143055570955252 +[01:59:01.349] +Epoch 8 Evaluation: +[01:59:58.537] average MSE: 0.042144566774368286 average PSNR: 29.35544952700144 average SSIM: 0.716946357779278 +[03:32:01.884] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_40000.pth +[04:45:19.018] +Epoch 9 Evaluation: +[04:46:15.972] average MSE: 0.04190925881266594 average PSNR: 29.379497378132424 average SSIM: 0.7149325949468087 +[07:01:29.345] +Epoch 10 Evaluation: +[07:02:26.647] average MSE: 0.041950855404138565 average PSNR: 29.3862971532986 average SSIM: 0.7155130461028231 +[09:48:03.486] +Epoch 11 Evaluation: 
+[09:48:58.906] average MSE: 0.04127601906657219 average PSNR: 29.46338663995586 average SSIM: 0.7135054797664598 +[11:54:03.641] +Epoch 12 Evaluation: +[11:54:59.419] average MSE: 0.04160526767373085 average PSNR: 29.429766785640194 average SSIM: 0.7156358769171677 +[14:53:03.223] +Epoch 13 Evaluation: +[15:09:07.567] average MSE: 0.04124404117465019 average PSNR: 29.47639666645126 average SSIM: 0.7161385386444223 +[15:53:06.553] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_60000.pth +[17:29:28.475] +Epoch 14 Evaluation: +[17:30:24.280] average MSE: 0.04155531898140907 average PSNR: 29.452468031310573 average SSIM: 0.7158154110879272 +[19:49:23.283] +Epoch 15 Evaluation: +[19:50:21.540] average MSE: 0.041638124734163284 average PSNR: 29.451690648093642 average SSIM: 0.7156771407207484 +[21:52:32.591] +Epoch 16 Evaluation: +[21:53:28.080] average MSE: 0.041410643607378006 average PSNR: 29.469163545230163 average SSIM: 0.7148224608238517 +[00:09:48.389] +Epoch 17 Evaluation: +[00:10:46.706] average MSE: 0.041037172079086304 average PSNR: 29.510351739428057 average SSIM: 0.716052565312664 +[02:11:53.354] +Epoch 18 Evaluation: +[02:12:52.093] average MSE: 0.041063591837882996 average PSNR: 29.506733835933726 average SSIM: 0.7154824060144455 +[02:35:10.288] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_80000.pth +[04:03:00.066] +Epoch 19 Evaluation: +[04:03:58.580] average MSE: 0.04100647568702698 average PSNR: 29.50608987280472 average SSIM: 0.7161391776241958 +[06:37:07.685] +Epoch 20 Evaluation: +[06:38:04.295] average MSE: 0.04089532420039177 average PSNR: 29.525126020792893 average SSIM: 0.716478199170695 +[09:05:17.798] +Epoch 21 Evaluation: +[09:06:14.710] average MSE: 0.040827859193086624 average PSNR: 29.528488234206872 average SSIM: 0.716142985256952 +[11:44:24.899] +Epoch 22 Evaluation: +[11:45:26.034] average MSE: 0.04078972712159157 average PSNR: 29.53480582746404 average SSIM: 0.7168395434823888 +[14:03:03.537] +Epoch 23 Evaluation: +[14:03:59.380] average MSE: 0.04082312807440758 average PSNR: 29.541606969359165 average SSIM: 0.7169637091926528 +[14:04:31.775] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_100000.pth +[16:31:04.823] +Epoch 24 Evaluation: +[16:32:03.791] average MSE: 0.04073634743690491 average PSNR: 29.546092238335273 average SSIM: 0.7171059677793982 +[18:59:16.427] +Epoch 25 Evaluation: +[19:00:15.776] average MSE: 0.04080486297607422 average PSNR: 29.538233997315114 average SSIM: 0.7169119520372566 +[21:48:31.612] +Epoch 26 Evaluation: +[21:49:28.916] average MSE: 0.040677882730960846 average PSNR: 29.553339036428856 average SSIM: 0.7172032084563584 +[23:58:31.755] +Epoch 27 Evaluation: +[23:59:31.847] average MSE: 0.04067990183830261 average PSNR: 29.55105232332524 average SSIM: 0.7173024428384114 +[01:35:13.079] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_120000.pth +[02:10:58.683] +Epoch 28 Evaluation: +[02:11:57.011] average MSE: 0.04063466191291809 average PSNR: 29.556597064858334 average SSIM: 0.7171909376945935 +[04:15:52.348] +Epoch 29 Evaluation: +[04:16:51.144] average MSE: 0.04067781940102577 average PSNR: 29.554275912518776 average SSIM: 0.7177268872095589 +[06:56:30.148] +Epoch 30 Evaluation: +[06:57:28.118] average MSE: 0.040646206587553024 average PSNR: 29.558332749595337 average SSIM: 0.717468852122194 +[09:53:19.173] +Epoch 31 Evaluation: +[09:54:23.533] average MSE: 0.040482811629772186 average PSNR: 29.574166511602055 average SSIM: 0.7168963214898668 +[12:56:02.747] +Epoch 32 Evaluation: 
+[12:56:58.658] average MSE: 0.040564555674791336 average PSNR: 29.56774236980243 average SSIM: 0.7176039977825167 +[14:06:49.741] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_140000.pth +[15:08:25.754] +Epoch 33 Evaluation: +[15:09:24.173] average MSE: 0.04059723764657974 average PSNR: 29.564359558833086 average SSIM: 0.7174119093262558 +[17:47:34.875] +Epoch 34 Evaluation: +[17:48:34.144] average MSE: 0.04060062766075134 average PSNR: 29.56416549935319 average SSIM: 0.7177988677156047 +[20:23:21.989] +Epoch 35 Evaluation: +[20:24:21.914] average MSE: 0.04056191444396973 average PSNR: 29.568903916429566 average SSIM: 0.7174859501313445 +[22:34:23.064] +Epoch 36 Evaluation: +[22:35:19.217] average MSE: 0.040594395250082016 average PSNR: 29.563233962971623 average SSIM: 0.717347662608815 +[01:55:26.233] +Epoch 37 Evaluation: +[01:56:25.247] average MSE: 0.040535397827625275 average PSNR: 29.573041633902772 average SSIM: 0.7175960175436993 +[02:57:58.294] save model to model/FSMNet_fastmri_4x_t5_kspace_time/iter_160000.pth +[04:08:37.257] +Epoch 38 Evaluation: +[04:09:37.383] average MSE: 0.04055703058838844 average PSNR: 29.569691942113327 average SSIM: 0.7175480348275984 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log/events.out.tfevents.1755254815.GCRSANDBOX133 b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log/events.out.tfevents.1755254815.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..269847faf98db5b34849ea58bacb058d50f00bf1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_4x_t5_kspace_time/log/events.out.tfevents.1755254815.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:567a3e6f67cf4e05230e794ecc096d65d40fb461071d63684e56714cb1419f0a +size 40 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/best_checkpoint.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/best_checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..cd9b72bcd7c22e15cbfe32b7b0d20b3868bf98c1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/best_checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf12d3321f847ae2acf3a1448d02719cdb8340e2e50160999c7c8cf2516ddb2c +size 56614874 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_100000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..7612bfb90f03804bc62b53fcb09b5f0e72492f4d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1924c28a244ec0952b1aabb4666692c3d29d6b58f6438f2c6c0d8551c971ddff +size 56611226 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_120000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_120000.pth new file mode 100644 index 0000000000000000000000000000000000000000..2e47c914083657f73ace04f533ec2c3e23b7aed2 --- /dev/null +++ 
b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_120000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36fd3f821caa595e6a3ccc3c7af219e5cace60256551a0fc3968a9537239a3ca +size 56611226 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_140000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_140000.pth new file mode 100644 index 0000000000000000000000000000000000000000..04dcb8a8ea37ec764cb1695647112d6459884c08 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_140000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a34ad0711b6cc415fd8d57f967b8fdefdbb882e898c15640cc115acaf32c39b4 +size 56611226 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_160000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_160000.pth new file mode 100644 index 0000000000000000000000000000000000000000..e054e163967618ac55ad450d9a43c432e174d9b6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_160000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa3263e0439b55863d6f56b53b5321c92aff936e16c0c45c0199591b90bd4a89 +size 56611226 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_20000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_20000.pth new file mode 100644 index 0000000000000000000000000000000000000000..a54826befff23ad0cffb9af3c0770cf1d114106b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_20000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:081a39b59c44067c415654961b41fd2f21b28736883040f31d342c6a50ee3814 +size 56610314 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_40000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_40000.pth new file mode 100644 index 0000000000000000000000000000000000000000..2780061d0726591e538fff7703bcfcb6a62e37a9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_40000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ac8e1df2c1a7e2ddcd950a6beac8b3ea5bc57737208262b0250a28bdfe642c2 +size 56610314 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_60000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_60000.pth new file mode 100644 index 0000000000000000000000000000000000000000..58caa119ec389a48bbc9d0b5e8fdc74bb59cc5d5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_60000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bfc6142593333e2c5c29ba3695cbebccc54546a2dc05c0f31ff33083e6f9103d +size 56610314 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_80000.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_80000.pth new 
file mode 100644 index 0000000000000000000000000000000000000000..c3532fe74a69cc71b5ab7f53e2511ace3d9fa559 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/iter_80000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:129885e55b8dcf66aab7068e608d7847e34cfd2125e5741e6502f897261c0f82 +size 56610314 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log.txt b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..f8d5cb9dca6d4a6ff087ee5666ae942a6130bc39 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log.txt @@ -0,0 +1,132 @@ +[02:44:11.748] Namespace(root_path='/home/v-qichen3/MRI_recon/data/fastmri', MRIDOWN='8X', low_field_SNR=0, phase='train', gpu='1', exp='FSMNet_fastmri_8x', max_iterations=200000, batch_size=2, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', grad_accum_steps=1, norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='equispaced', CENTER_FRACTIONS=[0.04], ACCELERATIONS=[8], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[02:45:51.390] Namespace(root_path='/datasdc16T/qichen/blob/qichen_blob/MRI_recon/data/fastmri', MRIDOWN='8X', low_field_SNR=0, phase='train', gpu='1', exp='FSMNet_fastmri_8x', max_iterations=200000, batch_size=2, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', grad_accum_steps=1, norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='equispaced', CENTER_FRACTIONS=[0.04], ACCELERATIONS=[8], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[02:46:12.787] Namespace(root_path='/home/v-qichen3/blob/qichen_blob/MRI_recon/data/fastmri', MRIDOWN='8X', low_field_SNR=0, phase='train', gpu='1', exp='FSMNet_fastmri_8x', max_iterations=200000, batch_size=2, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', grad_accum_steps=1, norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='equispaced', CENTER_FRACTIONS=[0.04], ACCELERATIONS=[8], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[03:30:51.402] Namespace(root_path='/home/v-qichen3/blob/qichen_blob/MRI_recon/data/fastmri', MRIDOWN='8X', low_field_SNR=0, phase='train', gpu='1', exp='FSMNet_fastmri_8x', 
max_iterations=200000, batch_size=2, base_lr=0.0001, seed=1337, resume=None, relation_consistency='False', clip_grad='True', grad_accum_steps=1, norm='False', input_normalize='mean_std', dist_url='63654', scale=8, base_num_every_group=2, rgb_range=255, n_colors=3, augment=False, fftloss=False, fftd=False, fftd_weight=0.1, fft_weight=0.01, model='MYNET', act='PReLU', data_range=1, num_channels=1, num_features=64, n_feats=64, res_scale=0.2, MASKTYPE='equispaced', CENTER_FRACTIONS=[0.04], ACCELERATIONS=[8], num_timesteps=5, image_size=320, distortion_sigma=0.0392156862745098, DEBUG=False, use_kspace=True, use_time_model=True, test_tag=None, test_sample='Ksample') +[06:29:42.601] +Epoch 0 Evaluation: +[06:30:42.170] average MSE: 0.06849547475576401 average PSNR: 26.981322521457056 average SSIM: 0.5469795436691905 +[08:35:38.534] +Epoch 1 Evaluation: +[08:36:36.387] average MSE: 0.0637686625123024 average PSNR: 27.3434317781291 average SSIM: 0.5636300727304532 +[11:15:16.164] +Epoch 2 Evaluation: +[11:16:14.566] average MSE: 0.06216821074485779 average PSNR: 27.463084212178273 average SSIM: 0.5646563056515655 +[13:41:17.683] +Epoch 3 Evaluation: +[13:42:15.944] average MSE: 0.06089402362704277 average PSNR: 27.589362141644415 average SSIM: 0.5731096715703814 +[15:12:26.605] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_20000.pth +[15:35:43.794] +Epoch 4 Evaluation: +[15:36:42.244] average MSE: 0.06020990014076233 average PSNR: 27.662149589825077 average SSIM: 0.5770954134092433 +[18:24:21.922] +Epoch 5 Evaluation: +[18:25:17.492] average MSE: 0.059002358466386795 average PSNR: 27.777219945177304 average SSIM: 0.5783281270612449 +[20:55:19.365] +Epoch 6 Evaluation: +[20:56:14.340] average MSE: 0.058727163821458817 average PSNR: 27.797634056298634 average SSIM: 0.5814021332562072 +[23:08:55.813] +Epoch 7 Evaluation: +[23:09:50.453] average MSE: 0.05836562439799309 average PSNR: 27.839798360031633 average SSIM: 0.5809144674504284 +[01:38:57.799] +Epoch 8 Evaluation: +[01:39:54.625] average MSE: 0.058350879698991776 average PSNR: 27.842734131451312 average SSIM: 0.5822874101908225 +[02:51:01.525] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_40000.pth +[04:08:10.554] +Epoch 9 Evaluation: +[04:09:04.715] average MSE: 0.05723675712943077 average PSNR: 27.95231869602451 average SSIM: 0.5843074759306646 +[06:38:25.004] +Epoch 10 Evaluation: +[06:39:20.330] average MSE: 0.0572502501308918 average PSNR: 27.947184290779123 average SSIM: 0.5856335530590412 +[09:10:38.382] +Epoch 11 Evaluation: +[09:11:35.753] average MSE: 0.057029902935028076 average PSNR: 27.98277787426478 average SSIM: 0.585356977794252 +[11:18:05.432] +Epoch 12 Evaluation: +[11:19:03.928] average MSE: 0.05679190531373024 average PSNR: 27.996928249238092 average SSIM: 0.5849190210160318 +[13:46:13.317] +Epoch 13 Evaluation: +[13:47:10.362] average MSE: 0.05672585964202881 average PSNR: 28.007734864863263 average SSIM: 0.586696863532803 +[15:11:50.805] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_60000.pth +[16:17:57.102] +Epoch 14 Evaluation: +[16:18:54.074] average MSE: 0.056276410818099976 average PSNR: 28.044647837953438 average SSIM: 0.5864728114342902 +[18:53:26.714] +Epoch 15 Evaluation: +[18:54:22.139] average MSE: 0.05619854852557182 average PSNR: 28.071009808795438 average SSIM: 0.5869360530801012 +[20:55:25.969] +Epoch 16 Evaluation: +[20:56:22.618] average MSE: 0.056052133440971375 average PSNR: 28.07746922787668 average SSIM: 0.5873917637082107 +[22:59:15.625] +Epoch 17 Evaluation: 
+[23:00:11.465] average MSE: 0.05619363859295845 average PSNR: 28.063145368492666 average SSIM: 0.5870929294202439 +[01:34:43.903] +Epoch 18 Evaluation: +[01:35:40.329] average MSE: 0.05639626085758209 average PSNR: 28.049744610800207 average SSIM: 0.5871393162434109 +[01:58:04.428] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_80000.pth +[03:25:54.032] +Epoch 19 Evaluation: +[03:26:50.663] average MSE: 0.055943187326192856 average PSNR: 28.088789247104277 average SSIM: 0.5880009169262379 +[05:41:48.785] +Epoch 20 Evaluation: +[05:42:45.921] average MSE: 0.056075435131788254 average PSNR: 28.079626473651846 average SSIM: 0.5876053683803258 +[08:18:38.268] +Epoch 21 Evaluation: +[08:19:31.716] average MSE: 0.05591992661356926 average PSNR: 28.09323764747977 average SSIM: 0.5884758394981815 +[10:46:49.560] +Epoch 22 Evaluation: +[10:47:47.404] average MSE: 0.055863138288259506 average PSNR: 28.09754619634455 average SSIM: 0.5883373255134637 +[13:16:49.214] +Epoch 23 Evaluation: +[13:17:48.170] average MSE: 0.05595538020133972 average PSNR: 28.087754953429272 average SSIM: 0.5883400094029779 +[13:18:29.345] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_100000.pth +[15:16:40.078] +Epoch 24 Evaluation: +[15:17:35.087] average MSE: 0.0558808334171772 average PSNR: 28.095824332130448 average SSIM: 0.587758044466449 +[18:03:47.624] +Epoch 25 Evaluation: +[18:04:42.746] average MSE: 0.05584316328167915 average PSNR: 28.102440971175394 average SSIM: 0.5885663817293969 +[20:21:47.598] +Epoch 26 Evaluation: +[20:22:47.221] average MSE: 0.0557008720934391 average PSNR: 28.114508492581702 average SSIM: 0.5887137689783611 +[23:11:45.266] +Epoch 27 Evaluation: +[23:12:42.367] average MSE: 0.05569044128060341 average PSNR: 28.11484317315951 average SSIM: 0.5888917997526566 +[00:47:56.137] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_120000.pth +[01:09:32.270] +Epoch 28 Evaluation: +[01:10:29.644] average MSE: 0.055704306811094284 average PSNR: 28.114929202486106 average SSIM: 0.5889343736368678 +[03:34:35.788] +Epoch 29 Evaluation: +[03:35:29.827] average MSE: 0.05570918321609497 average PSNR: 28.11546545245588 average SSIM: 0.5887347465431754 +[06:05:34.568] +Epoch 30 Evaluation: +[06:06:32.290] average MSE: 0.055728066712617874 average PSNR: 28.112451967406937 average SSIM: 0.5887477536609557 +[08:44:16.362] +Epoch 31 Evaluation: +[08:45:17.622] average MSE: 0.055712755769491196 average PSNR: 28.11672984576411 average SSIM: 0.5887804619973747 +[11:27:28.178] +Epoch 32 Evaluation: +[11:28:27.795] average MSE: 0.055720508098602295 average PSNR: 28.1144814222241 average SSIM: 0.5886445747487673 +[13:17:57.771] save model to model/FSMNet_fastmri_8x_t5_kspace_time/iter_140000.pth +[14:03:43.232] +Epoch 33 Evaluation: +[14:04:44.134] average MSE: 0.05565781518816948 average PSNR: 28.12143983097114 average SSIM: 0.5889584049171944 +[16:42:41.017] +Epoch 34 Evaluation: +[16:43:36.256] average MSE: 0.055682916194200516 average PSNR: 28.118533189968854 average SSIM: 0.5888331281636389 +[19:37:37.794] +Epoch 35 Evaluation: +[19:38:35.324] average MSE: 0.05567517876625061 average PSNR: 28.120045678232838 average SSIM: 0.5891015421239777 +[21:46:36.674] +Epoch 36 Evaluation: +[21:47:34.263] average MSE: 0.055661678314208984 average PSNR: 28.11972343383556 average SSIM: 0.5888810876038033 +[00:27:08.941] +Epoch 37 Evaluation: +[00:28:05.869] average MSE: 0.05561193451285362 average PSNR: 28.12555343179277 average SSIM: 0.5889431283279748 +[01:53:15.420] save model to 
model/FSMNet_fastmri_8x_t5_kspace_time/iter_160000.pth +[03:16:56.436] +Epoch 38 Evaluation: +[03:17:51.932] average MSE: 0.05565853789448738 average PSNR: 28.122168483034176 average SSIM: 0.5890587148700736 +[05:22:49.759] +Epoch 39 Evaluation: +[05:23:45.697] average MSE: 0.05562131479382515 average PSNR: 28.12490446768612 average SSIM: 0.5890293074492906 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log/events.out.tfevents.1755254815.GCRSANDBOX133 b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log/events.out.tfevents.1755254815.GCRSANDBOX133 new file mode 100644 index 0000000000000000000000000000000000000000..112754f69bfd88a71d08696673f8b77caa37855b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/model/FSMNet_fastmri_8x_t5_kspace_time/log/events.out.tfevents.1755254815.GCRSANDBOX133 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:544b6f697fafdfc2fe3c8e082ba923ce88826ae83bd716a97cdd4dbc7681d0f2 +size 40 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__pycache__/__init__.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2d7c1dda08d8786c287fa8c9ca38f4590eda1bc Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__pycache__/common_freq.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__pycache__/common_freq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8233f0e2831ef27ed72e3ce489affac08c6b22d6 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__pycache__/common_freq.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__pycache__/mynet.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__pycache__/mynet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b0004b587dcc984530acbabcd036aa739413059 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/__pycache__/mynet.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/common_freq.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..79cf3e778029a846b4da910c115c8315bf33dbaf --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/common_freq.py @@ -0,0 +1,389 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, 
tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs): + + out = self.layers(inputs) + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features): + super(ResBlock, self).__init__() + self.layers = nn.Sequential( + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1), + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + ) + + def forward(self, inputs): + return 
F.relu(self.layers(inputs) + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [ + ResBlock(n_feat) for _ in range(n_resblocks)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x): + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels, args): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = 
torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5)
+ return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ARTUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae23fed16b0d9454a5ea09f17b4322f1e9d7e928 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ARTUnet.py @@ -0,0 +1,751 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + 
return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = 
self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + # print("x shape:", x.shape) + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +########################################################################## +class CS_TransformerBlock(nn.Module): + r""" 在ART Transformer Block的基础上添加了channel attention的分支。 + 参考论文: Dual Aggregation Transformer for Image Super-Resolution + 及其代码: https://github.com/zhengchen1999/DAT/blob/main/basicsr/archs/dat_arch.py + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.InstanceNorm2d(dim), + nn.GELU()) + + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + # nn.InstanceNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1)) + + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.InstanceNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # print("x before dwconv:", x.shape) + # convolution output + conv_x = self.dwconv(x.permute(0, 3, 1, 2)) + # conv_x = x.permute(0, 3, 1, 2) + # print("x after dwconv:", x.shape) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + attened_x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + attened_x = attened_x[:, :H, :W, :].contiguous() + + attened_x = attened_x.view(B, H * W, C) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C) + # S-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W) + spatial_map = self.spatial_interaction(attention_reshape) + + # C-I + attened_x = attened_x * torch.sigmoid(channel_map) + # S-I + conv_x = torch.sigmoid(spatial_map) * conv_x + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + + x = attened_x + conv_x + + # 
FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ARTUnet_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ARTUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d210d46e2a4d73781b3b16b63a53fdc0f25b9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ARTUnet_new.py @@ -0,0 +1,823 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +Unet的格式构建image restoration 网络。每个transformer block中都是dense self-attention和sparse self-attention交替连接。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True, visualize_attention_maps=False): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + self.visualize_attention_maps = visualize_attention_maps + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + # assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + # qkv: 3, B_, num_heads, N, C // num_heads + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + # B_, num_heads, N, C // num_heads * + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + attn_map = attn.detach().cpu().numpy().mean(axis=1) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) # (B * nP, nHead, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + 
# attn: B * nP, num_heads, N, N + # v: B * nP, num_heads, N, C // num_heads + # -attn-@-v-> B * nP, num_heads, N, C // num_heads + # -transpose-> B * nP, N, num_heads, C // num_heads + # -reshape-> B * nP, N, C + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, G ** 2, G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, Gh * Gw, Gh * Gw) + else: + x = self.attn(x, Gh, Gw, mask=attn_mask) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +class ConcatTransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. 
+ num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + ################## + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // 
self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + ################## + + x = torch.cat((x, y), dim=1) + + shortcut = x + x = self.norm1(x) + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, 2 * Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, 2 * G ** 2, 2 * G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, 2 * Gh * Gw, 2 * Gh * Gw) + else: + x = self.attn(x, 2 * Gh, Gw, mask=attn_mask) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + output = x + + # print(x.shape, Gh, Gw) + x = output[:, :Gh * Gw, :] + y = output[:, Gh * Gw:, :] + assert x.shape == y.shape + + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = y.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = y.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, y, attn_map + else: + return x, y + + def get_patches(self, x, x_size): + H, W = x_size + B, H, W, C = x.shape + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, 
Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + return x, attn_mask, (Hd, Wd), (Gh, Gw), (pad_r, pad_b), G, I + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ART_Restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ART_Restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c23123f0ac1a8e82e2bf907658ce8d91951c3e4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ART_Restormer.py @@ -0,0 +1,592 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure, 每个block都是由ART和Restormer串联而成。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[4, 4, 4, 4], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + + self.down1_2 = Downsample(dim) ## 
From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + 
self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + + ######################################### Encoder level1 ############################################ + ### ART encoder block + for layer in self.ART_encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + + ### Restormer encoder block + out_enc_level1 = rearrange(out_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = self.Restormer_encoder_level1(out_enc_level1) + # stack.append(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + ######################################### Encoder level2 ############################################ + ### ART encoder block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.ART_encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + out_enc_level2 = rearrange(out_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = self.Restormer_encoder_level2(out_enc_level2) + # stack.append(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.ART_encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + out_enc_level3 = rearrange(out_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = self.Restormer_encoder_level3(out_enc_level3) + # stack.append(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 
############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.ART_latent: + latent = layer(latent, [H // 8, W // 8]) + + ### Restormer encoder block + latent = rearrange(latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = self.Restormer_latent(latent) + # stack.append(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + out_dec_level3 = inp_dec_level3 + for layer in self.ART_decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + out_dec_level3 = rearrange(out_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + out_dec_level3 = self.Restormer_decoder_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + out_dec_level2 = inp_dec_level2 + for layer in self.ART_decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + out_dec_level2 = rearrange(out_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + out_dec_level2 = self.Restormer_decoder_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + out_dec_level1 = inp_dec_level1 + for layer in self.ART_decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + ### Restormer decoder block + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_decoder_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_refinement(out_dec_level1) + + out_dec_level1 = 
self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ART_Restormer_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ART_Restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..de8f921e81d79cb7af14a75326f0ddaaf3b3f4c3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ART_Restormer_v2.py @@ -0,0 +1,663 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure; each block is composed of an ART branch and a Restormer branch connected in parallel. +The input features are fed through the ART and Restormer branches separately; the features from the two branches are concatenated and then transformed by a conv layer to obtain the block's output features. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.enc_fuse_level1 = nn.Conv2d(int(dim * 
2 ** 1), dim, kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.enc_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.enc_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + self.enc_fuse_level4 = nn.Conv2d(int(dim * 2 ** 4), int(dim * 2 ** 3), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.dec_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], 
window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.dec_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.dec_fuse_level1 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + # self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + # bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + # self.fuse_refinement = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + # def forward(self, inp_img, aux_img): + # stack = [] + # bs, _, H, W = inp_img.shape + + # fuse_img = torch.cat((inp_img, aux_img), 1) + # inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + + def forward(self, inp_img): + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + + ######################################### Encoder level1 ############################################ + art_enc_level1 = inp_enc_level1 + restor_enc_level1 = inp_enc_level1 + + ### ART encoder block + for layer in self.ART_encoder_level1: + art_enc_level1 = layer(art_enc_level1, [H, W]) + + ### Restormer encoder block + restor_enc_level1 = rearrange(restor_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_enc_level1 = self.Restormer_encoder_level1(restor_enc_level1) ### (b, c, h, w) + + art_enc_level1 = rearrange(art_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = torch.cat((art_enc_level1, restor_enc_level1), 1) + out_enc_level1 = self.enc_fuse_level1(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + + ######################################### Encoder level2 ############################################ + ### ART encoder 
block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + art_enc_level2 = inp_enc_level2 + restor_enc_level2 = inp_enc_level2 + + for layer in self.ART_encoder_level2: + art_enc_level2 = layer(art_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + restor_enc_level2 = rearrange(restor_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + restor_enc_level2 = self.Restormer_encoder_level2(restor_enc_level2) + + art_enc_level2 = rearrange(art_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = torch.cat((art_enc_level2, restor_enc_level2), 1) + out_enc_level2 = self.enc_fuse_level2(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + art_enc_level3 = inp_enc_level3 + restor_enc_level3 = inp_enc_level3 + + for layer in self.ART_encoder_level3: + art_enc_level3 = layer(art_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + restor_enc_level3 = rearrange(restor_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + restor_enc_level3 = self.Restormer_encoder_level3(restor_enc_level3) + + art_enc_level3 = rearrange(art_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = torch.cat((art_enc_level3, restor_enc_level3), 1) + out_enc_level3 = self.enc_fuse_level3(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 ############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + art_latent = inp_enc_level4 + restor_latent = inp_enc_level4 + + for layer in self.ART_latent: + art_latent = layer(art_latent, [H // 8, W // 8]) + + ### Restormer encoder block + restor_latent = rearrange(restor_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + restor_latent = self.Restormer_latent(restor_latent) + + art_latent = rearrange(art_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = torch.cat((art_latent, restor_latent), 1) + latent = self.enc_fuse_level4(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + art_dec_level3 = inp_dec_level3 + restor_dec_level3 = inp_dec_level3 + + for layer in self.ART_decoder_level3: + art_dec_level3 = layer(art_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + restor_dec_level3 = rearrange(restor_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + restor_dec_level3 = 
self.Restormer_decoder_level3(restor_dec_level3) + + art_dec_level3 = rearrange(art_dec_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_dec_level3 = torch.cat((art_dec_level3, restor_dec_level3), 1) + out_dec_level3 = self.dec_fuse_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + art_dec_level2 = inp_dec_level2 + restor_dec_level2 = inp_dec_level2 + + for layer in self.ART_decoder_level2: + art_dec_level2 = layer(art_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + restor_dec_level2 = rearrange(restor_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + restor_dec_level2 = self.Restormer_decoder_level2(restor_dec_level2) + + art_dec_level2 = rearrange(art_dec_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_dec_level2 = torch.cat((art_dec_level2, restor_dec_level2), 1) + out_dec_level2 = self.dec_fuse_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + art_dec_level1 = inp_dec_level1 + restor_dec_level1 = inp_dec_level1 + + for layer in self.ART_decoder_level1: + art_dec_level1 = layer(art_dec_level1, [H, W]) + + ### Restormer decoder block + restor_dec_level1 = rearrange(restor_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_dec_level1 = self.Restormer_decoder_level1(restor_dec_level1) + + art_dec_level1 = rearrange(art_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = torch.cat((art_dec_level1, restor_dec_level1), 1) + out_dec_level1 = self.dec_fuse_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +# def build_model(args): +# return ART_Restormer( +# inp_channels=2, +# out_channels=1, +# dim=32, +# num_art_blocks=[2, 4, 4, 6], +# num_restormer_blocks=[2, 2, 2, 2], +# num_refinement_blocks=4, +# heads=[1, 2, 4, 8], +# window_size=[10, 10, 10, 10], mlp_ratio=4., +# qkv_bias=True, qk_scale=None, +# interval=[24, 12, 6, 3], +# ffn_expansion_factor = 2.66, +# LayerNorm_type = 'WithBias', ## Other option 'BiasFree' +# bias=False) + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 
2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ARTfuse_layer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ARTfuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..78af0b27528b50db897936f6b8b4eee51d566b1c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/ARTfuse_layer.py @@ -0,0 +1,705 @@ +""" +ART layer in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input to the Attention layer:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + # print("q range:", q.max(), q.min()) + # print("k range:", k.max(), k.min()) + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + # print("attention matrix:", attn.shape, attn) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + # print("feature before proj-layer:", x.max(), x.min()) + x = self.proj(x) + # print("feature after proj-layer:", x.max(), x.min()) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. 
+ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pre_norm=True): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + # self.conv = nn.Conv2d(dim, dim, kernel_size=1) + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + # self.conv = ConvBlock(self.dim, self.dim, drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.bn_norm = nn.BatchNorm2d(dim) + self.pre_norm = pre_norm + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + # x = self.conv(x) + shortcut = x + if self.pre_norm: + # print("using pre_norm") + x = self.norm1(x) ## the feature range becomes very small after normalization + x = x.view(B, H, W, C) + # print("normalized input range:", x.max(), x.min(), x.mean()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag ==
1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + # print("range of feature before layer:", x.max(), x.min(), x.mean()) + # x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + # x_bn = self.bn_norm(x.permute(0, 3, 1, 2)) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + # print("[ART layer] input and output feature difference:", torch.mean(torch.abs(shortcut - x))) + # print("range of feature before ART layer:", shortcut.max(), shortcut.min(), shortcut.mean()) + # print("range of feature after ART layer:", x.max(), x.min(), x.mean()) + if self.pre_norm: + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + # print("using post_norm") + x = self.norm1(shortcut + self.drop_path(x)) + x = self.norm2(x + self.drop_path(self.mlp(x))) + # x = shortcut + self.drop_path(x) + # x = x + self.drop_path(self.mlp(x)) + + # print("range of fused feature:", x.max(), x.min()) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + + + +########################################################################## +class Cross_TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +class Cross_TransformerBlock_v2(nn.Module): + r""" ART Transformer Block. + 将两个模态的特征沿着channel方向concat, 一起取window. 之后把取出来的两个模态中同一个window的所有patches合并,送到transformer处理。 + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + xy = torch.cat((x, y), -1) + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + xy = F.pad(xy, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + xy = xy.reshape(B, Hd // G, G, Wd // G, G, 2*C).permute(0, 1, 3, 2, 4, 5).contiguous() + xy = xy.reshape(B * Hd * Wd // G ** 2, G ** 2, 2*C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + xy = xy.reshape(B, Gh, I, Gw, I, 2*C).permute(0, 2, 4, 1, 3, 5).contiguous() + xy = xy.reshape(B * I * I, Gh * Gw, 2*C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ x = xy[:, :, :C] + y = xy[:, :, C:] + xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DataConsistency.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DataConsistency.py new file mode 100644 index 0000000000000000000000000000000000000000..5f460e9acabcb33e1a3f812eb9acaeb337da446c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DataConsistency.py @@ -0,0 +1,48 @@ +""" +Created: DataConsistency @ Xiyang Cai, 2023/09/09 + +Data consistency layer for k-space signal. 
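+The layer keeps the network prediction only at k-space locations that were not
+sampled and copies the originally acquired samples wherever the mask is nonzero:
+out = (1 - mask) * k + mask * k0.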
+ +Ref: DataConsistency in DuDoRNet (https://github.com/bbbbbbzhou/DuDoRNet) + +""" + +import torch +from torch import nn +from einops import repeat + + +def data_consistency(k, k0, mask): + """ + k - input in k-space + k0 - initially sampled elements in k-space + mask - corresponding nonzero location + """ + out = (1 - mask) * k + mask * k0 + return out + + +class DataConsistency(nn.Module): + """ + Create data consistency operator + """ + + def __init__(self): + super(DataConsistency, self).__init__() + + def forward(self, k, k0, mask): + """ + k - input in frequency domain, of shape (n, nx, ny, 2) + k0 - initially sampled elements in k-space + mask - corresponding nonzero location (n, 1, len, 1) + """ + + if k.dim() != 4: # input is 2D + raise ValueError("error in data consistency layer!") + + # mask = repeat(mask.squeeze(1, 3), 'b x -> b x y c', y=k.shape[1], c=2) + mask = torch.tile(mask, (1, mask.shape[2], 1, k.shape[-1])) ### [n, 320, 320, 2] + # print("k and k0 shape:", k.shape, k0.shape) + out = data_consistency(k, k0, mask) + + return out, mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DuDo_mUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DuDo_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a56069a27f0cdd392482b41dc4e3b79252295487 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DuDo_mUnet.py @@ -0,0 +1,359 @@ +""" +Xiaohan Xing, 2023/10/25 +传入两个模态的kspace data, 经过kspace network进行重建。 +然后把重建之后的kspace data变换回图像,然后送到image domain的网络进行图像重建。 +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + +from .Kspace_mUnet import kspace_mmUnet + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = kspace_ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = kspace_ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = kspace_ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = kspace_ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. 
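+        # Editor's note: freq_filter collapses the 3*chans multi-scale features into
+        # a single sigmoid map in (0, 1) that gates every frequency location. Because
+        # these features live in k-space, this pointwise weighting plays the role of
+        # a convolution with an image-sized kernel in the image domain, as the
+        # comment above states.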
+ spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class DuDo_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 3, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + # self.kspace_net = mmConvKSpace() + self.kspace_net = kspace_mmUnet(args) + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + stack = [] + feature_stack = [] + # output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ output = torch.cat((image, LR_image, aux_image), 1) #.detach() + + # print("LR image range:", LR_image.max(), LR_image.min()) + # print("kspace_recon image range:", image.max(), image.min()) + # print("aux_image range:", aux_image.max(), aux_image.min()) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output, image, recon_kspace, masked_kspace, data_stats + + + +class kspace_ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return DuDo_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6df67b63140ca51bd7b8664151ab4b1b7a6305d5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py @@ -0,0 +1,369 @@ +""" +Xiaohan Xing, 2023/11/08 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + ARTfusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after 
transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + output1, output2 = aux_image.detach(), LR_image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c2b86b6e2031b1af07ab7cb58efc182b7a2673 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py @@ -0,0 +1,338 @@ +""" +Xiaohan Xing, 2023/11/10 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + multi layer concat fusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, 
"reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_CATfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t2_gt: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + t2_image = t2_gt * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + # # print("aux_image range:", aux_image.max(), aux_image.min()) + # print("LR_image range:", LR_image.max(), LR_image.min()) + # print("kspace recon_img range:", image.max(), image.min()) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + t2_gt_image = self.normalize(t2_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + t2_gt_image = torch.clamp(t2_gt_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + # output1, output2 = aux_image.detach(), LR_image.detach() + output1, output2 = aux_image.detach(), image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image, recon_kspace, masked_kspace, data_stats + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_CATfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Kspace_ConvNet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Kspace_ConvNet.py new file mode 100644 index 0000000000000000000000000000000000000000..918cbe598a62653394c8c4eaf99cd10a45c064ca --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Kspace_ConvNet.py @@ -0,0 +1,140 @@ +""" +2023/10/05 +Xiaohan Xing +Build a simple network with several convolution layers. 
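+Three parallel branches with 3x3, 5x5 and 7x7 kernels are fused by a learned
+sigmoid frequency filter and refined by a final conv block, followed by a
+data-consistency layer that restores the acquired k-space samples.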
+Input: concat (undersampled kspace of FS-PD, fully-sampled kspace of PD modality) +Output: interpolated kspace of FS-PD +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, args, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + # self.conv_blocks = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # self.conv_blocks.append(ConvBlock(self.chans, self.chans * 2, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans * 2, self.chans, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans, self.out_chans, drop_prob)) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + # print("output of the three conv blocks:", output1.shape, output2.shape, output3.shape) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("conv1_output range:", output1.max(), output1.min()) + # print("conv2_output range:", output2.max(), output2.min()) + # print("conv3_output range:", output3.max(), output3.min()) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. + spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + # print("output before DC layer:", output.shape) + # print("output range before DC layer:", output.max(), output.min()) + # output = output + kspace ## residual connection + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + # mask_output = output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
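+    Note: in this k-space variant the LeakyReLU activations are commented out in
+    the layer stack below, so each stage reduces to convolution, instance
+    normalization and dropout.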
+ """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +def build_model(args): + return mmConvKSpace(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Kspace_mUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Kspace_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2db9a357798be4abd4cfbd44e9207555770b0456 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Kspace_mUnet.py @@ -0,0 +1,202 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # print("output:", output.shape) + # print("kspace:", kspace.shape) + # print("mask:", mask.shape) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Kspace_mUnet_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Kspace_mUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ccfd28dc198ffeff2259a5fbed45dabdc76819 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Kspace_mUnet_new.py @@ -0,0 +1,200 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # print("using Unet without Relu layers") + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + # output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # # print("output:", output.shape) + # # print("kspace:", kspace.shape) + # # print("mask:", mask.shape) + # mask_output = mask * output + + return output # .permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/MINet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeeb063b12be1dc594e8e076e6a59e454fcfa0b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/MINet.py @@ -0,0 +1,356 @@ +""" +Compare with the method "Multi-contrast MRI Super-Resolution via a Multi-stage Integration Network". +The code is from https://github.com/chunmeifeng/MINet/edit/main/fastmri/models/MINet.py. 
+""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F + + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): 
+ def __init__(self,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 1 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + # print("modules_tail:", modules_tail) + # self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Sigmoid()) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + + + +class MINet(nn.Module): + def __init__(self, args, n_resgroups=3, n_resblocks=3, n_feats=64): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + # print("x1:", x1.shape, "x2:", x2.shape) + + ### The original code is for super-resolution, where self.tail upsamples the image;
since we do reconstruction here, that step is effectively removed (SR_Branch is built with scale = 1, so tail adds no upsampling). + x2 = self.tail(x2) + + + resT1 = x1 + resT2 = x2 + + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + # print("t1s:", [item.shape for item in t1s]) + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + # print("resT1 range:", resT1.max(), resT1.min()) + # print("resT2 range:", resT2.max(), resT2.min()) + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1, x2 #x1=pd x2=pdfs + + +def build_model(args): + return MINet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/MINet_common.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/MINet_common.py new file mode 100644 index 0000000000000000000000000000000000000000..631f3036db4edf8eab00d70c66d2058a3b65cd59 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/MINet_common.py @@ -0,0 +1,90 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n?
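+ # For a power-of-two scale, stack (conv to 4 * n_feats -> PixelShuffle(2)) blocks int(log2(scale)) times; with scale == 1, as in the MINet / SANet configs in this repo, the loop adds nothing and no spatial upsampling occurs.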
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + # print("upsampler:", m) + + super(Upsampler, self).__init__(*m) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/SANet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/SANet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfac3590e248c9532b002885c6c0b7a55ab1400a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/SANet.py @@ -0,0 +1,392 @@ +""" +This script contains the codes of "Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution (TNNLS 2023)" +https://github.com/chunmeifeng/SANet/blob/main/fastmri/models/SANet.py +""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import os + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, 
width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,scale,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups #10 + self.n_resblocks = n_resblocks #20 + self.n_feats = n_feats #64 + kernel_size = 3 + reduction = 16 #16 + # scale = args.scale[0] + + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + 
print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' + .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class Seatt(nn.Module): + def __init__(self, in_c): + super(Seatt, self).__init__() + self.reduce = nn.Conv2d(in_c * 2, 32, 1) + self.ff_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.bf_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.rgbd_pred_layer = Pred_Layer(32 * 2) + self.convq = nn.Conv2d(32, 32, 3, 1, 1) + + def forward(self, rgb_feat, dep_feat, pred): + feat = torch.cat((rgb_feat, dep_feat), 1) + + # vis_attention_map_rgb_feat(rgb_feat[0][0]) + # vis_attention_map_dep_feat(dep_feat[0][0]) + # vis_attention_map_feat(feat[0][0]) + + + feat = self.reduce(feat) + [_, _, H, W] = feat.size() + pred = torch.sigmoid(pred) + + ni_pred = 1 - pred + + ff_feat = self.ff_conv(feat * pred) + bf_feat = self.bf_conv(feat * (1 - pred)) + new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) + + return new_pred + +class SANet(nn.Module): + def __init__(self, args, scale=1, n_resgroups=3, n_resblocks=3, n_feats=64): + super(SANet, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + + cs = [64,64,64,64] + self.Seatts = nn.ModuleList([Seatt(c) for c in cs]) + + self.net1 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + print("nlayer:",nlayer) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=3, padding=1) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2, Seatt, map_conv in zip(self.net1.body._modules.items(), self.net2.body._modules.items(), self.Seatts, 
self.map_convs): + + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + pred = map_conv(resT1) + res = Seatt(resT1,resT2,pred) + + resT2 = res+resT2 + + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + + return x1,x2 + + +def build_model(args): + return SANet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/SwinFuse_layer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/SwinFuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3dca8ba793d92961bc3f1d987e211fe542b303 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/SwinFuse_layer.py @@ -0,0 +1,1310 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. 
+ num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # print("x shape:", x.shape) + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + # print("num heads:", self.num_heads) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * 
self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + 
relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + # print("x_windows:", x_windows.shape) + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + + # # 不使用残差连接 + # x = self.residual_group_A(x, x_size) + # y = self.residual_group_B(y, x_size) + + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion_layer(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
+ Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Fusion_num_heads=[6, 6], + window_size=7, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, img_range=1., resi_connection='1conv', + **kwargs): + super(SwinFusion_layer, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + 
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + + ### fusion layer. + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + self.norm_Fusion_B = norm_layer(self.num_features) + + # print("fusion layers:", len(self.layers_Fusion)) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + + def forward_features_Fusion(self, x, y): + + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + # x = torch.cat([x, y], 1) + # ## Downsample the feature in the channel dimension + # x = self.lrelu(self.conv_after_body_Fusion(x)) + # print("x range after fusion:", x.max(), x.min()) + # print("y 
range after fusion:", y.max(), y.min()) + # print("x:", x.shape) + + return x, y + + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = self.check_image_size(y) + + print("modality1 input:", x.max(), x.min()) + print("modality2 input:", y.max(), y.min()) + + # multi-modal feature fusion. + x, y = self.forward_features_Fusion(x, y) ### 返回两个模态融合之后的features. + print("modality1 output:", x.max(), x.min()) + print("modality2 output:", y.max(), y.min()) + + return x[:, :, :H, :W], y[:, :, :H, :W] + + + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + + +if __name__ == '__main__': + width, height = 120, 120 + swinfuse = SwinFusion_layer(img_size=(width, height), window_size=8, depths=[6, 6, 6], embed_dim=60, num_heads=[6, 6, 6]) + # print("SwinFusion layer:", swinfuse) + + A = torch.randn((4, 60, width, height)) + B = torch.randn((4, 60, width, height)) + + x = swinfuse(A, B) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/SwinFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/SwinFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3e72c2a3e8859423f3544043e8b9feaec002d0e2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/SwinFusion.py @@ -0,0 +1,1468 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # print("qkv after reshape:", qkv.shape) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. 
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + # 不使用残差连接 + # print("input x shape:", x.shape) + x = self.residual_group_A(x, x_size) + y = self.residual_group_B(y, x_size) + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
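+    This variant takes two modality inputs A and B, extracts features with separate RSTB
+    branches (layers_Ex_A / layers_Ex_B), fuses them with cross-modal CRSTB blocks
+    (layers_Fusion), and reconstructs the output with a final RSTB stack (layers_Re)
+    followed by three convolutions and a Tanh.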
+ + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Re_depths=[4], + Ex_num_heads=[6], Fusion_num_heads=[6, 6], Re_num_heads=[6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinFusion, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + print('in_chans: ', in_chans) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + ####修改shallow feature extraction 网络, 修改为2个3x3的卷积#### + self.conv_first1_A = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first1_B = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first2_A = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1) + self.conv_first2_B = nn.Conv2d(embed_dim_temp, embed_dim_temp, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + self.Re_num_layers = len(Re_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = 
patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + dpr_Re = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Re_depths))] # stochastic depth decay rule + # build Residual Swin Transformer blocks (RSTB) + self.layers_Ex_A = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_A.append(layer) + self.norm_Ex_A = norm_layer(self.num_features) + + self.layers_Ex_B = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_B.append(layer) + self.norm_Ex_B = norm_layer(self.num_features) + + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + 
self.norm_Fusion_B = norm_layer(self.num_features) + + self.layers_Re = nn.ModuleList() + for i_layer in range(self.Re_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Re_depths[i_layer], + num_heads=Re_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Re[sum(Re_depths[:i_layer]):sum(Re_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Re.append(layer) + self.norm_Re = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last1 = nn.Conv2d(embed_dim, embed_dim_temp, 3, 1, 1) + self.conv_last2 = nn.Conv2d(embed_dim_temp, int(embed_dim_temp/2), 3, 1, 1) + self.conv_last3 = nn.Conv2d(int(embed_dim_temp/2), num_out_ch, 3, 1, 1) + + self.tanh = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features_Ex_A(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_A: + x = layer(x, x_size) + + x = self.norm_Ex_A(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_Ex_B(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_B: + x = layer(x, x_size) + + x = self.norm_Ex_B(x) # B L C + x = self.patch_unembed(x, x_size) + return x + + def forward_features_Fusion(self, x, y): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + # print("x token:", x.shape, "y token:", y.shape) + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + x = torch.cat([x, y], 1) + ## Downsample the feature in the channel dimension + x = self.lrelu(self.conv_after_body_Fusion(x)) + + return x + + def forward_features_Re(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Re: + x = layer(x, x_size) + + x = self.norm_Re(x) # B L C + x = self.patch_unembed(x, x_size) + # Convolution + x = self.lrelu(self.conv_last1(x)) + x = self.lrelu(self.conv_last2(x)) + x = self.conv_last3(x) + return x + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = 
self.check_image_size(y) + + # self.mean_A = self.mean.type_as(x) + # self.mean_B = self.mean.type_as(y) + # self.mean = (self.mean_A + self.mean_B) / 2 + + # x = (x - self.mean_A) * self.img_range + # y = (y - self.mean_B) * self.img_range + + # Feedforward + x = self.forward_features_Ex_A(x) + y = self.forward_features_Ex_B(y) + # print("x before fusion:", x.shape) + # print("y before fusion:", y.shape) + x = self.forward_features_Fusion(x, y) + x = self.forward_features_Re(x) + x = self.tanh(x) + + # x = x / self.img_range + self.mean + # print("H:", H, "upscale:", self.upscale) + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='convs') + + +if __name__ == '__main__': + model = SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='convs') + # print(model) + + A = torch.randn((1, 1, 240, 240)) + B = torch.randn((1, 1, 240, 240)) + + x = model(A, B) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/TransFuse.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/TransFuse.py new file mode 100644 index 0000000000000000000000000000000000000000..56ca3d2b24f382603c876e4f6909363a9794ffa3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/TransFuse.py @@ -0,0 +1,155 @@ +import math +from collections import deque + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torchvision import models + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. 
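+    Note: despite the "masked" wording above, forward() applies full (unmasked) attention;
+    no causal mask is constructed or used in this layer.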
+ """ + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(n_embd, n_embd) + self.query = nn.Linear(n_embd, n_embd) + self.value = nn.Linear(n_embd, n_embd) + # regularization + self.attn_drop = nn.Dropout(attn_pdrop) + self.resid_drop = nn.Dropout(resid_pdrop) + # output projection + self.proj = nn.Linear(n_embd, n_embd) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop): + super().__init__() + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.mlp = nn.Sequential( + nn.Linear(n_embd, block_exp * n_embd), + nn.ReLU(True), # changed from GELU + nn.Linear(block_exp * n_embd, n_embd), + nn.Dropout(resid_pdrop), + ) + + def forward(self, x): + B, T, C = x.size() + + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + + return x + + +class TransFuse_layer(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, n_embd, n_head, block_exp, n_layer, + num_anchors, seq_len=1, + embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1): + super().__init__() + self.n_embd = n_embd + self.seq_len = seq_len + self.vert_anchors = num_anchors + self.horz_anchors = num_anchors + + # positional embedding parameter (learnable), image + lidar + self.pos_emb = nn.Parameter(torch.zeros(1, 2 * seq_len * self.vert_anchors * self.horz_anchors, n_embd)) + + self.drop = nn.Dropout(embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(n_embd, n_head, + block_exp, attn_pdrop, resid_pdrop) + for layer in range(n_layer)]) + + # decoder head + self.ln_f = nn.LayerNorm(n_embd) + + self.block_size = seq_len + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + + def forward(self, m1, m2): + """ + Args: + m1 (tensor): B*seq_len, C, H, W + m2 (tensor): B*seq_len, C, H, W + """ + + bz = m2.shape[0] // self.seq_len + h, w = m2.shape[2:4] + + # forward the image model for token embeddings + m1 = m1.view(bz, self.seq_len, -1, h, w) + m2 = m2.view(bz, self.seq_len, -1, h, w) + + # pad token embeddings along number of tokens dimension + token_embeddings = torch.cat([m1, m2], 
dim=1).permute(0,1,3,4,2).contiguous() + token_embeddings = token_embeddings.view(bz, -1, self.n_embd) # (B, an * T, C) + + # add (learnable) positional embedding for all tokens + x = self.drop(self.pos_emb + token_embeddings) # (B, an * T, C) + x = self.blocks(x) # (B, an * T, C) + x = self.ln_f(x) # (B, an * T, C) + x = x.view(bz, 2 * self.seq_len, self.vert_anchors, self.horz_anchors, self.n_embd) + x = x.permute(0,1,4,2,3).contiguous() # same as token_embeddings + + m1_out = x[:, :self.seq_len, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + m2_out = x[:, self.seq_len:, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + print("modality1 output:", m1_out.max(), m1_out.min()) + print("modality2 output:", m2_out.max(), m2_out.min()) + + return m1_out, m2_out + + +if __name__ == "__main__": + + feature1 = torch.randn((4, 512, 20, 20)) + feature2 = torch.randn((4, 512, 20, 20)) + model = TransFuse_layer(n_embd=512, n_head=4, block_exp=4, n_layer=8, num_anchors=20, seq_len=1) + print("TransFuse_layer:", model) + feat1, feat2 = model(feature1, feature2) + print(feat1.shape, feat2.shape) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Unet_ART.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Unet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b0f834270a921ae4eb428c92b3e782fbed19b3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/Unet_ART.py @@ -0,0 +1,243 @@ +""" +2023/07/06, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
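+
+            Note: after each encoder ConvBlock the features pass through a 1x1 conv
+            (transform_layers) and a stack of ART TransformerBlocks; the transformer
+            output replaces the convolutional features before average pooling (see forward()).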
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + output = conv_layer(output) + + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..334f1821a9b59e692641f70cdc61e7655b1a9c89 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/__init__.py @@ -0,0 +1,114 @@ +from .unet import build_model as UNET +from .mmunet import build_model as MUNET +from .humus_net import build_model as HUMUS +from .MINet import build_model as MINET +from .SANet import build_model as SANET +from .mtrans_net import build_model as MTRANS +from .unimodal_transformer import build_model as TRANS +from .cnn_transformer_new import build_model as CNN_TRANS +from .cnn import build_model as CNN +from .unet_transformer import build_model as UNET_TRANSFORMER +from .swin_transformer import build_model as SWIN_TRANS +from .restormer import build_model as RESTORMER + + +from .cnn_late_fusion import build_model as MULTI_CNN +from .mmunet_late_fusion import build_model as MMUNET_LATE +from .mmunet_early_fusion import build_model as MMUNET_EARLY +from .trans_unet.trans_unet import build_model as TRANSUNET +from .SwinFusion import build_model as SWINMULTI +from .mUnet_transformer import build_model as MUNET_TRANSFORMER +from .munet_multi_transfuse import build_model as MUNET_TRANSFUSE +from .munet_swinfusion import build_model as MUNET_SWINFUSE +from .munet_multi_concat import build_model as MUNET_CONCAT +from .munet_concat_decomp import build_model as MUNET_CONCAT_DECOMP +from .munet_multi_sum import build_model as MUNET_SUM +# from .mUnet_ARTfusion_new import build_model as MUNET_ART_FUSION +from .mUnet_ARTfusion import build_model as MUNET_ART_FUSION +from .mUnet_Restormer_fusion import build_model as MUNET_RESTORMER_FUSION +from .mUnet_ARTfusion_SeqConcat import build_model as MUNET_ART_FUSION_SEQ + +from .ARTUnet import build_model as ART +from .Unet_ART import build_model as UNET_ART +from .unet_restormer import build_model as UNET_RESTORMER +from .mmunet_ART import build_model as MMUNET_ART +from .mmunet_restormer import build_model as MMUNET_RESTORMER +from .mmunet_restormer_v2 import build_model as MMUNET_RESTORMER_V2 +from .mmunet_restormer_3blocks import build_model as MMUNET_RESTORMER_SMALL + +from .mARTUnet import build_model as ART_MULTI_INPUT + +from .ART_Restormer import build_model as ART_RESTORMER +from .ART_Restormer_v2 import build_model as ART_RESTORMER_V2 +from .swinIR import build_model as SWINIR + +from .Kspace_ConvNet import build_model as KSPACE_CONVNET +from .Kspace_mUnet_new import build_model as KSPACE_MUNET +from .kspace_mUnet_concat import build_model as KSPACE_MUNET_CONCAT +from .kspace_mUnet_AttnFusion import build_model as KSPACE_MUNET_ATTNFUSE + +from .DuDo_mUnet import build_model as DUDO_MUNET +from .DuDo_mUnet_ARTfusion import build_model as DUDO_MUNET_ARTFUSION +from 
.DuDo_mUnet_CatFusion import build_model as DUDO_MUNET_CONCAT + + +model_factory = { + 'unet_single': UNET, + 'humus_single': HUMUS, + 'transformer_single': TRANS, + 'cnn_transformer': CNN_TRANS, + 'cnn_single': CNN, + 'swin_trans_single': SWIN_TRANS, + 'trans_unet': TRANSUNET, + 'unet_transformer': UNET_TRANSFORMER, + 'restormer': RESTORMER, + 'unet_art': UNET_ART, + 'unet_restormer': UNET_RESTORMER, + 'art': ART, + 'art_restormer': ART_RESTORMER, + 'art_restormer_v2': ART_RESTORMER_V2, + 'swinIR': SWINIR, + + + 'munet_transformer': MUNET_TRANSFORMER, + 'munet_transfuse': MUNET_TRANSFUSE, + 'cnn_late_multi': MULTI_CNN, + 'unet_multi': MUNET, + 'unet_late_multi':MMUNET_LATE, + 'unet_early_multi':MMUNET_EARLY, + 'munet_ARTfusion': MUNET_ART_FUSION, + 'munet_restormer_fusion': MUNET_RESTORMER_FUSION, + + + 'minet_multi': MINET, + 'sanet_multi': SANET, + 'mtrans_multi': MTRANS, + 'swin_fusion': SWINMULTI, + 'munet_swinfuse': MUNET_SWINFUSE, + 'munet_concat': MUNET_CONCAT, + 'munet_concat_decomp': MUNET_CONCAT_DECOMP, + 'munet_sum': MUNET_SUM, + 'mmunet_art': MMUNET_ART, + 'mmunet_restormer': MMUNET_RESTORMER, + 'mmunet_restormer_small': MMUNET_RESTORMER_SMALL, + 'mmunet_restormer_v2': MMUNET_RESTORMER_V2, + + 'art_multi_input': ART_MULTI_INPUT, + + 'munet_ARTfusion_SeqConcat': MUNET_ART_FUSION_SEQ, + + 'kspace_ConvNet': KSPACE_CONVNET, + 'kspace_munet': KSPACE_MUNET, + 'kspace_munet_concat': KSPACE_MUNET_CONCAT, + 'kspace_munet_AttnFusion': KSPACE_MUNET_ATTNFUSE, + + 'dudo_munet': DUDO_MUNET, + 'dudo_munet_ARTfusion': DUDO_MUNET_ARTFUSION, + 'dudo_munet_concat': DUDO_MUNET_CONCAT, +} + + +def build_model_from_name(args): + assert args.model_name in model_factory.keys(), 'unknown model name' + + return model_factory[args.model_name](args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c130178e7cdabaaef8686da0cec341265f3c43d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn.py @@ -0,0 +1,246 @@ +""" +把our method中的单模态CNN_Transformer中的transformer换成几个conv_block用来下采样提取特征, 看是否能够达到和Unet相似的效果。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Reconstructor(nn.Module): + def __init__(self): + super(CNN_Reconstructor, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + nlatent = self.encoder.output_dim + self.decoder = Decoder(n_upsample=4, n_res=1, dim=nlatent, output_dim=1, pad_type='zero') + + + def forward(self, m): + #4,1,240,240 + feature_output = self.encoder.model(m) # [4, 256, 60, 60] + # print("output of the encoder:", feature_output.shape) + + output = self.decoder(feature_output) + + return output + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + 
nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + + def forward(self, input): + return self.model(input) + + + 
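+# Shape sketch for CNN_Reconstructor above (a rough trace, assuming the default 1x240x240 input):
+#   encoder.model: 1x240x240 -> 64x240x240 -> 128x120x120 -> 256x60x60 -> 512x30x30 -> 1024x15x15
+#   decoder:       1024x15x15 -> 512x30x30 -> 256x60x60 -> 128x120x120 -> 64x240x240 -> 1x240x240 (Tanh)
+# e.g. CNN_Reconstructor()(torch.randn(1, 1, 240, 240)) returns a (1, 1, 240, 240) reconstruction.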
+################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + + +def build_model(args): + return CNN_Reconstructor() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn_late_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..58dacad6910b6e896d00a0d36c9ea60e8d77217f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn_late_fusion.py @@ -0,0 +1,196 @@ +""" +在lequan的代码上,去掉transformer-based feature fusion部分。 +直接用CNN提取特征之后concat作为multi-modal representation. 用普通的decoder进行图像重建。 +把T1作为guidance modality, T2作为target modality. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
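+
+    This late-fusion variant (T1 is used as the guidance modality and T2 as the reconstruction
+    target, as described in the module header) encodes the two inputs with separate CNN
+    encoders, concatenates the bottleneck features, and decodes with transpose convolutions
+    only, i.e. without transformer fusion or skip connections.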
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers1 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + self.down_sample_layers2 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers1.append(ConvBlock(ch, ch * 2, drop_prob)) + self.down_sample_layers2.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + + # self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # ch = chans + # for _ in range(num_pool_layers - 1): + # self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + # ch *= 2 + # self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv = nn.Sequential(nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), nn.Tanh()) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = torch.cat([image, aux_image], 1) + # for layer in self.down_sample_layers: + # output = layer(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + + # apply down-sampling layers + output = image + for layer in self.down_sample_layers1: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + output_m1 = output + + output = aux_image + for layer in self.down_sample_layers2: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + output_m2 = output + + # print("two modality outputs:", output_m1.shape, output.shape) + output = torch.cat([output_m1, output_m2], 1) + + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv in self.up_transpose_conv: + output = transpose_conv(output) + + output = self.up_conv(output) + # print("model output:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..261d4d214e35c3a9be5d28bc426ce7280124723c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn_transformer.py @@ -0,0 +1,289 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 
送到transformer网络中进行MRI重建。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=2, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=2, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 60 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 4 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
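+        # 225 tokens = (fmp_size // patch_size) ** 2 = (60 // 4) ** 2; each token has patch_dim = 1024 channels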
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + feature_output = self.upsampling1(feature_output) # [4,512,30,30] + feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn_transformer_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn_transformer_new.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b2dea1e7cc8d456a684b6c839f052383999e84 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/cnn_transformer_new.py @@ -0,0 +1,290 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 送到transformer网络中进行MRI重建。 +2023/05/23: 之前是从尺寸为60*60的特征图上切patch, 现在可以尝试直接用conv_blocks提取15*15的特征图,然后按照patch_size = 1切成225个patches。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=4, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 1 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
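+        # Shape note (a sketch based on the defaults above, assuming a 240x240 single-channel input):
+        # with n_downsample=4 the encoder output m_out is [bs, 1024, 15, 15], so the 1x1 patch conv
+        # below yields 15*15 = 225 tokens of dimension patch_dim = 1024; the "[4, 256, 60, 60]" comment
+        # above appears to carry over from the earlier 60x60-feature-map variant mentioned in the docstring.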
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + # feature_output = self.upsampling1(feature_output) # [4,512,30,30] + # feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/humus_net.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/humus_net.py new file mode 100644 index 0000000000000000000000000000000000000000..d23701634f849b414866d71b3c23c6d07f3d6437 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/humus_net.py @@ -0,0 +1,1070 @@ +""" +HUMUS-Block +Hybrid Transformer-convolutional Multi-scale denoiser. + +We use parts of the code from +SwinIR (https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py) +Swin-Unet (https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.py) + +HUMUS-Net中是对kspace data操作, 多个unrolling cascade, 每个cascade中都利用HUMUS-block对图像进行denoising. +因为我们的方法输入是Under-sampled images, 所以不需要进行unrolling, 直接使用一个HUMUS-block进行MRI重建。 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # print("x shape:", x.shape) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. 
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + 
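+        # window_reverse is the exact inverse of window_partition (for H, W divisible by window_size),
+        # so the attended windows are stitched back into a (B, H, W, C) map before the cyclic shift
+        # is undone below. A quick sanity-check sketch (hypothetical, not part of the original code):
+        #   t = torch.randn(2, 20, 20, 32)
+        #   assert torch.equal(window_reverse(window_partition(t, 5), 5, 20, 20), t)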
# reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + +class PatchExpandSkip(nn.Module): + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) + self.channel_mix = nn.Linear(dim, dim // 2, bias=False) + self.norm = norm_layer(dim // 2) + + def forward(self, x, skip): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x = torch.cat([x, skip], dim=-1) + x = self.channel_mix(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, torch.tensor(x_size)) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ block_type: + D: downsampling block, + U: upsampling block + B: bottleneck block + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', block_type='B'): + super(RSTB, self).__init__() + self.dim = dim + conv_dim = dim // (patch_size ** 2) + divide_out_ch = 1 + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(conv_dim, conv_dim // divide_out_ch, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(conv_dim, conv_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // divide_out_ch, 3, 1, 1)) + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.block_type = block_type + if block_type == 'B': + self.reshape = torch.nn.Identity() + elif block_type == 'D': + self.reshape = PatchMerging(input_resolution, dim) + elif block_type == 'U': + self.reshape = PatchExpandSkip([res // 2 for res in input_resolution], dim * 2) + else: + raise ValueError('Unknown RSTB block type.') + + def forward(self, x, x_size, skip=None): + if self.block_type == 'U': + assert skip is not None, "Skip connection is required for patch expand" + x = self.reshape(x, skip) + + out = self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) + block_out = self.patch_embed(out) + x + + if self.block_type == 'D': + block_out = (self.reshape(block_out), block_out) # return skip connection + + return block_out + +class PatchEmbed(nn.Module): + r""" Fixed Image to Patch Embedding. Flattens image along spatial dimensions + without learned projection mapping. Only supports 1x1 patches. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + embed_dim (int): Number of output channels. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + +class PatchEmbedLearned(nn.Module): + r""" Image to Patch Embedding with arbitrary patch size and learned linear projection. 
+ Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_chans * patch_size[0] * patch_size[1], embed_dim) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.patch_to_vec(x) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + def patch_to_vec(self, x): + """ + Args: + x: (B, C, H, W) + Returns: + patches: (B, num_patches, patch_size * patch_size * C) + """ + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(B, H // self.patch_size[0], self.patch_size[0], W // self.patch_size[1], self.patch_size[1], C) + patches = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.patch_size[0], self.patch_size[1], C) + patches = patches.contiguous().view(B, self.num_patches, self.patch_size[0] * self.patch_size[1] * C) + return patches + + +class PatchUnEmbed(nn.Module): + r""" Fixed Image to Patch Unembedding via reshaping. Only supports patch size 1x1. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + +class PatchUnEmbedLearned(nn.Module): + r""" Patch Embedding to Image with learned linear projection. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. 
+ """ + + def __init__(self, img_size, patch_size, in_chans=None, out_chans=None, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.out_chans = out_chans + + self.proj_up = nn.Linear(in_chans, in_chans * patch_size[0] * patch_size[1]) + + def forward(self, x, x_size=None): + B, HW, C = x.shape + x = self.proj_up(x) + x = self.vec_to_patch(x) + return x + + def vec_to_patch(self, x): + B, HW, C = x.shape + x = x.view(B, self.patches_resolution[0], self.patches_resolution[1], C).contiguous() + x = x.view(B, self.patches_resolution[0], self.patch_size[0], self.patches_resolution[1], self.patch_size[1], C // (self.patch_size[0] * self.patch_size[1])).contiguous() + x = x.view(B, self.img_size[0], self.img_size[1], -1).contiguous().permute(0, 3, 1, 2) + return x + + +class HUMUSNet(nn.Module): + r""" HUMUS-Block + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depths (tuple(int)): Depth of each Swin Transformer layer in encoder and decoder paths. + num_heads (tuple(int)): Number of attention heads in different layers of encoder and decoder. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + use_checkpoint (bool): Whether to use checkpointing to save memory. + img_range: Image range. 1. or 255. + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + conv_downsample_first: use convolutional downsampling before MUST to reduce compute load on Transformers + """ + + def __init__(self, args, + img_size=[240, 240], + in_chans=1, + patch_size=1, + embed_dim=66, + depths=[2, 2, 2], + num_heads=[3, 6, 12], + window_size=5, + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + img_range=1., + resi_connection='1conv', + bottleneck_depth=2, + bottleneck_heads=24, + conv_downsample_first=True, + out_chans=1, + no_residual_learning=False, + **kwargs): + super(HUMUSNet, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans if out_chans is None else out_chans + + self.center_slice_out = (out_chans == 1) + + self.img_range = img_range + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + self.conv_downsample_first = conv_downsample_first + self.no_residual_learning = no_residual_learning + + ##################################################################################################### + ################################### 1, input block ################################### + input_conv_dim = embed_dim + self.conv_first = nn.Conv2d(num_in_ch, input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** self.num_layers) + self.mlp_ratio = mlp_ratio + + # Downsample for low-res feature extraction + if self.conv_downsample_first: + img_size = [im //2 for im in img_size] + self.conv_down_block = ConvBlock(input_conv_dim // 2, input_conv_dim, 0.0) + self.conv_down = DownsampConvBlock(input_conv_dim, input_conv_dim) + + # split image into non-overlapping patches + if patch_size > 1: + self.patch_embed = PatchEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + if patch_size > 1: + self.patch_unembed = PatchUnEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, out_chans=embed_dim, norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build MUST + # encoder + self.layers_down = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** i_layer) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + 
patches_resolution[1] // dim_scaler), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='D', + ) + self.layers_down.append(layer) + + # bottleneck + dim_scaler = (2 ** self.num_layers) + self.layer_bottleneck = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=bottleneck_depth, + num_heads=bottleneck_heads, + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='B', + ) + + # decoder + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** (self.num_layers - i_layer - 1)) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='U', + ) + self.layers_up.append(layer) + + self.norm_down = norm_layer(self.num_features) + self.norm_up = norm_layer(self.embed_dim) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(input_conv_dim, input_conv_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(input_conv_dim, input_conv_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim, 3, 1, 1)) + + # Upsample if needed + if self.conv_downsample_first: + self.conv_up_block = ConvBlock(input_conv_dim, input_conv_dim // 2, 0.0) + self.conv_up = TransposeConvBlock(input_conv_dim, input_conv_dim // 2) + + ##################################################################################################### + ################################ 3, output block ################################ + self.conv_last = nn.Conv2d(input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, num_out_ch, 3, 1, 1) + self.output_norm = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + 
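+    # Usage sketch (hypothetical, not part of the original code; shapes follow the defaults above,
+    # and `args` is not referenced inside this constructor):
+    #   model = HUMUSNet(args=None, img_size=[240, 240], in_chans=1, out_chans=1)
+    #   x = torch.randn(1, 1, 240, 240)   # under-sampled magnitude image
+    #   y = model(x)                      # -> (1, 1, 240, 240), Tanh-bounded reconstruction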
@torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + + # divisible by window size + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + + # divisible by total downsampling ratio + # this could be done more efficiently by combining the two + total_downsamp = int(2 ** (self.num_layers - 1)) + pad_h = (total_downsamp - h % total_downsamp) % total_downsamp + pad_w = (total_downsamp - w % total_downsamp) % total_downsamp + x = F.pad(x, (0, pad_w, 0, pad_h)) + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # encode + skip_cons = [] + for i, layer in enumerate(self.layers_down): + x, skip = layer(x, x_size=layer.input_resolution) + skip_cons.append(skip) + x = self.norm_down(x) # B L C + + # bottleneck + x = self.layer_bottleneck(x, self.layer_bottleneck.input_resolution) + + # decode + for i, layer in enumerate(self.layers_up): + x = layer(x, x_size=layer.input_resolution, skip=skip_cons[-i-1]) + x = self.norm_up(x) + x = self.patch_unembed(x, x_size) + return x + + def forward(self, x): + # print("input shape:", x.shape) + # print("input range:", x.max(), x.min()) + C, H, W = x.shape[1:] + center_slice = (C - 1) // 2 + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.conv_downsample_first: + x_first = self.conv_first(x) + x_down = self.conv_down(self.conv_down_block(x_first)) + res = self.conv_after_body(self.forward_features(x_down)) + res = self.conv_up(res) + res = torch.cat([res, x_first], dim=1) + res = self.conv_up_block(res) + + res = self.conv_last(res) + + if self.no_residual_learning: + x = res + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + res + else: + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + + if self.no_residual_learning: + x = self.conv_last(res) + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + self.conv_last(res) + + # print("output range before scaling:", x.max(), x.min()) + # print("self.mean:", self.mean) + + if self.center_slice_out: + x = x / self.img_range + self.mean[:, center_slice, ...].unsqueeze(1) + else: + x = x / self.img_range + self.mean + + x = self.output_norm(x) + + return x[:, :, :H, :W] + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + +class DownsampConvBlock(nn.Module): + """ + A Downsampling Convolutional Block that consists of one strided convolution + layer followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.Conv2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H/2, W/2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return HUMUSNet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..02588c24559c0604ddbbd03a14b3b32e6cff2c8d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py @@ -0,0 +1,300 @@ +""" +Xiaohan Xing, 2023/11/22 +两个模态分别用CNN提取多个层级的特征, 每个层级都把特征沿着宽度拼接起来,得到(h, 2w, c)的特征。 +然后学习得到(h, 2w)的attention map, 和原始的特征相乘送到后面的层。 +""" + +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+        """
+        super().__init__()
+
+        self.in_chans = input_dim
+        self.out_chans = output_dim
+        self.chans = chans
+        self.num_pool_layers = num_pool_layers
+        self.drop_prob = drop_prob
+
+        if args.HF_refine == "True":
+            self.in_chans = self.in_chans * 2
+
+        self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob)
+        self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob)
+        ch = self.encoder1.ch
+
+        # print("encoder layers:", self.encoder1.down_sample_layers)
+
+        self.attn_layers = nn.ModuleList()
+
+        for l in range(self.num_pool_layers):
+
+            self.attn_layers.append(nn.Sequential(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob),
+                                                  nn.Conv2d(chans*(2**l), 2, kernel_size=1, stride=1),
+                                                  nn.Sigmoid()))
+
+        self.conv1 = ConvBlock(ch, ch * 2, drop_prob)
+        self.conv2 = ConvBlock(ch, ch * 2, drop_prob)
+
+        self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob)
+        self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob)
+
+
+    def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor:
+        """
+        Args:
+            kspace: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns:
+            Output tensor of shape `(N, out_chans, H, W)`.
+        """
+        # output = self.encoder1.(output)
+        # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
+        # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape)
+        # if recon_kspace is not None:
+        # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape)
+
+
+        # print("T2 input_kspace range:", kspace.max(), kspace.min())
+        # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min())
+        # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min())
+        # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min())
+
+        if recon_kspace is not None:
+            output1 = torch.cat((kspace, recon_kspace), 1)
+            output2 = torch.cat((ref_kspace, recon_ref_kspace), 1)
+        else:
+            output1, output2 = kspace, ref_kspace
+
+        stack1, stack2 = [], []
+        t1_features, t2_features = [], []
+        relation_stack1, relation_stack2 = [], []
+
+        ### The two modalities are first encoded by separate CNN layers, then fused by the transfuse
+        ### (attention) layer; the pre- and post-fusion features are summed and passed on to the following layers.
+        for net1_layer, net2_layer, attn_layer in zip(
+                self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.attn_layers):
+            ### The encoder features taken before multi-modal fusion are routed to the decoder via skip connections.
+            output1 = net1_layer(output1)
+            stack1.append(output1)
+            t1_features.append(output1)
+            output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0)
+
+            output2 = net2_layer(output2)
+            stack2.append(output2)
+            t2_features.append(output2)
+            output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0)
+
+            ### Concatenate the two modalities' pooled features, predict the attention map, then modulate each branch.
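+            # The 2-channel Sigmoid output acts as a pair of spatial gates: channel 0 re-weights the
+            # T2 (target) branch and channel 1 the T1 (reference) branch via x * (1 + attn), i.e. a
+            # residual modulation that scales each location by a factor in (1, 2).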
+ attn = attn_layer(torch.cat((output1, output2), 1)) + # print("spatial attention:", attn[:, 0, :, :].unsqueeze(1).shape) + # print("output1:", output1.shape) + output1 = output1 * (1 + attn[:, 0, :, :].unsqueeze(1)) + output2 = output2 * (1 + attn[:, 1, :, :].unsqueeze(1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/kspace_mUnet_concat.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/kspace_mUnet_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..d16bfdc83305d3e1f490e5ca4b445c6d730557ce --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/kspace_mUnet_concat.py @@ -0,0 +1,351 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Predict the low-frequency mask ############## +class Mask_Predictor(nn.Module): + def __init__(self, in_chans, out_chans, chans, drop_prob): + super(Mask_Predictor, self).__init__() + self.conv1 = ConvBlock(in_chans, chans, drop_prob) + self.conv2 = ConvBlock(chans, chans*2, drop_prob) + self.conv3 = ConvBlock(chans*2, chans*4, drop_prob) + + # self.conv4 = ConvBlock(chans*4, 1, drop_prob) + + self.FC = nn.Linear(chans*4, out_chans) + self.act = nn.Sigmoid() + + + def forward(self, x): + ### 三层卷积和down-sampling. + # print("[Mask predictor] input data range:", x.max(), x.min()) + output = F.avg_pool2d(self.conv1(x), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer1_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv2(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer2_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv3(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer3_feature range:", output.max(), output.min()) + + feature = F.adaptive_avg_pool2d(F.relu(output), (1, 1)).squeeze(-1).squeeze(-1) + # print("mask_prediction feature:", output[:, :5]) + # print("[Mask predictor] bottleneck feature range:", feature.max(), feature.min()) + + output = self.FC(feature) + # print("output before act:", output) + output = self.act(output) + # print("mask output:", output) + + return output + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + 
# nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.mask_predictor1 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + self.mask_predictor2 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
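+
+ Note: this variant actually returns six values: the two reconstructed k-space tensors
+ plus the predicted low-frequency mask coordinates and DC weights for both modalities
+ (output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight).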
+ """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + ### 预测做low-freq DC的mask区域. + ### 输出三个维度, 前两维是坐标,最后一维是mask内部的权重 + t1_mask = self.mask_predictor1(output1) + t2_mask = self.mask_predictor2(output2) + + # print("t1_mask and t2_mask:", t1_mask.shape, t2_mask.shape) + t1_mask_coords, t2_mask_coords = t1_mask[:, :2], t2_mask[:, :2] + t1_DC_weight, t2_DC_weight = t1_mask[:, -1], t2_mask[:, -1] + + + # ### 预测做low-freq DC的DC weight + # t1_DC_weight = self.mask_predictor1(output1) + # t2_DC_weight = self.mask_predictor2(output2) + + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight ## + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mARTUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9607d41ca14562a94e542a1804af5aeaf15bc123 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mARTUnet.py @@ -0,0 +1,563 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +2023/07/20, Xiaohan Xing +将两个模态的图像在输入端concat, 得到num_channels = 2的input, 然后送到ART模型进行T2 modality的重建。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. 
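+ (DAB = dense attention over local window_size x window_size groups; SAB = sparse
+ attention over positions sampled at the given interval.)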
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # print("feature after layer_norm:", x.max(), x.min()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", 
attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=2, + out_channels=1, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img, aux_img): + stack = [] + bs, _, H, W = inp_img.shape + fuse_img = torch.cat((inp_img, aux_img), 1) + inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + 
self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=2, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_ARTfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..610b1e1e38d74f695a3a19f22a6a80f3152bf713 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_ARTfusion.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import CS_TransformerBlock as TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class 
mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + # output1, output2 = image1, image2 + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
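+ ### English gloss: per-modality conv features are extracted first; the pre-fusion
+ ### features are pushed onto stack1/stack2 and reused by the decoders as skip connections,
+ ### while the fused features continue to the next encoder level.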
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py new file mode 100644 index 0000000000000000000000000000000000000000..55280925415f1db47cf746b59fb9710d47c9bff2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py @@ -0,0 +1,374 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any + +import matplotlib.pyplot as plt +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet_new import ConcatTransformerBlock, TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return 
output + + + + +class mUnet_ARTfusion_SeqConcat(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4. + ): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.vis_feature_maps = False + self.vis_attention_maps = False + # self.vis_feature_maps = args.MODEL.VIS_FEAT_MAPS + # self.vis_attention_maps = args.MODEL.VIS_ATTN_MAPS + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + if l < 2: + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * (2 ** (l + 1))), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + else: + self.fuse_layers.append(nn.ModuleList([ConcatTransformerBlock(dim=int(chans * (2 ** l)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + feature_maps = { + 'modal1': {'before': [], 'after': []}, + 'modal2': {'before': [], 'after': []} + } + + attention_maps = [] + """ + attention_maps = [ + unet layer 1 attention maps (x2): [map from TransBlock1, TransBlock2, ...], + unet layer 2 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 3 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 4 attention maps (x6): [map from TransBlock1, TransBlock2, ...], + ] + """ + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
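+ ### English gloss: extract per-modality conv features; in this SeqConcat variant the
+ ### features are first passed through the fusion blocks below, and the fused (post-fusion)
+ ### features are what gets pushed onto stack1/stack2 for the decoder skip connections.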
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + if self.vis_feature_maps: + feature_maps['modal1']['before'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['before'].append(output2.detach().cpu().numpy()) + + ### 两个模态的特征concat + # output = torch.cat((output1, output2), 1) + + + + unet_layer_attn_maps = [] + """ + unet_layer_attn_maps: [ + attn_map from TransformerBlock1: (B, nHGroups, nWGroups, winSize, winSize), + attn_map from TransformerBlock2: (B, nHGroups, nWGroups, winSize, winSize), + ... + ] + """ + if l < 2: + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + else: + output1 = rearrange(output1, "b c h w -> b (h w) c").contiguous() + output2 = rearrange(output2, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + if l < 2: + if self.vis_attention_maps: + output, attn_map = layer(output, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output = layer(output, [H // (2**l), W // (2**l)]) + + else: + if self.vis_attention_maps: + output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + attention_maps.append(unet_layer_attn_maps) + + if l < 2: + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + else: + output1 = rearrange(output1, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output2 = rearrange(output2, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + # if self.vis_attention_maps: + # output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) # attn_map: (B, nHGroups, nWGroups, winSize, winSize) + # unet_layer_attn_maps.append(attn_map) + # else: + # output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + # attention_maps.append(unet_layer_attn_maps) + + + + if self.vis_feature_maps: + feature_maps['modal1']['after'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['after'].append(output2.detach().cpu().numpy()) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + if self.vis_feature_maps: + return output1, output2, feature_maps#, relation_stack1, relation_stack2 #, t1_features, t2_features + elif self.vis_attention_maps: + return output1, output2, attention_maps + else: + return output1, output2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
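+ (English gloss: each feature map is pooled down to a 5x5 grid and a 25x25 cosine-similarity
+ relation matrix between the spatial positions is returned.)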
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion_SeqConcat(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8ab29b641fa0231ad0d163224a4a90b44c32a9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .restormer import TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8]): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.Sequential(*[TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[0], ffn_expansion_factor=2.66, \ + bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + ### 将encoder中multi-modal fusion之后的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = fuse_layer(output) + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
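+ (5x5 pooled features; returns a 25x25 cosine-similarity relation matrix.)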
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2a209758203b198502154b75496333ef91717a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mUnet_transformer.py @@ -0,0 +1,387 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+ +Xiaohan Xing, 2023/06/08 +分别用CNN提取两个模态的特征,然后送到transformer进行特征交互, 将经过transformer提取的特征和之前CNN提取的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
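+            # each intermediate decoder level: the transpose conv upsamples 2x and halves the channels (2*ch -> ch); after the skip feature is concatenated, the ConvBlock below maps 2*ch back to ch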
self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Decoder_wo_skip(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder_wo_skip, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + output = transpose_conv(output) + output = conv(output) + + return output + + + + + +class mUnet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch*2 + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1, output2 = image1, image2 + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack1, output1 = self.encoder1(output1) ### output size: [4, 256, 15, 15] + stack2, output2 = self.encoder2(output2) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output1) # [4, 225, 256], 225 is the number of patches. + patch_embed_input1 = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + n_embed = self.to_patch_embedding(output2) # [4, 225, 256], 225 is the number of patches. 
+ patch_embed_input2 = F.normalize(n_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input1, patch_embed_input2), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1]/2)), int(np.sqrt(feature_output.shape[1]/2)) + patch_num = feature_output.shape[1] + feature_output1 = feature_output[:, :patch_num//2, :] + feature_output2 = feature_output[:, patch_num//2:, :] + feature_output1 = feature_output1.contiguous().view(b, feature_output1.shape[-1], h, w) # [4,c,15,15] + feature_output2 = feature_output2.contiguous().view(b, feature_output2.shape[-1], h, w) + + output = torch.cat((output1, output2), 1) + torch.cat((feature_output1, feature_output2), 1) + # print("CNN feature range:", output1.max(), output1.min()) + # print("Transformer feature range:", feature_output1.max(), feature_output1.min()) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_Transformer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee0c4784d04638df4ca259613777dde34980a2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/04 +Concatenate the input images from different modalities and input it into the UNet. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_ART.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..1a09f2900a8924cdc1a7c373270e6ff8aadcfa45 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_ART.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for 
i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_ART_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_ART_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe6bc0d9192a126f39932b40443badb5b60068 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_ART_v2.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. 
+ chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
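+        Note: at each encoder level the fused feature map is flattened to (N, H*W, C), refined by that level's ART transformer blocks at spatial size self.img_size // 2**level, and reshaped back before average pooling.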
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_early_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_early_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d062f057a2080a471005125a00ee27453a4b0706 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_early_fusion.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +encapsulate the encoder and decoder for the Unet model. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + stack, output = self.encoder(output) + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + output = self.decoder(output, stack) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_late_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b48e7d4445afd6a5bc1eaa2b065c37431585e94 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_late_fusion.py @@ -0,0 +1,229 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +use Unet for the reconstruction of each modality, concat the intermediate features of both modalities. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("decoder output:", output.shape) + + return output + + + + +class mmUnet_late(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack1, output1 = self.encoder1(image) + stack2, output2 = self.encoder2(aux_image) + ### fuse two modalities to get multi-modal representation. + output = torch.cat([output1, output2], 1) + output = self.conv(output) ### from [4, 512, 15, 15] to [4, 512, 15, 15] + + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet_late(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..88c7920d65c58ab82a00309a1ae501a8d5b33804 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_restormer.py @@ -0,0 +1,237 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e3afe68612727d3195872a121ec8040fb0f66ef2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py @@ -0,0 +1,235 @@ +""" +2023/07/18, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +之前 3 blocks的模型深层特征存在严重的gradient vanishing问题。现在把模型的encoder和decoder都改成只有3个blocks, 看是否还存在gradient vanishing问题。 +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 3, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2] + heads=[1, 2, 4] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_restormer_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b51d3a218057206863230973d557173e891d3cd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mmunet_restormer_v2.py @@ -0,0 +1,220 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +这个版本中是将Unet encoder提取的特征直接连接到decoder, 经过restormer refine的特征只是传递到下一层的encoder block. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + output = feature + + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mtrans_mca.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mtrans_mca.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4c89df4af71e986fa7c24d522e6732371f3128 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mtrans_mca.py @@ -0,0 +1,119 @@ +import torch +from torch import nn + +from .mtrans_transformer import Mlp + +# implement by YunluYan + + +class LinearProject(nn.Module): + def __init__(self, dim_in, dim_out, fn): + super(LinearProject, self).__init__() + need_projection = dim_in != dim_out + self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity() + self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity() + self.fn = fn + + def forward(self, x, *args, **kwargs): + x = self.project_in(x) + x = self.fn(x, *args, **kwargs) + x = self.project_out(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): + super(MultiHeadCrossAttention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim*2, bias=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x, complement): + + # x [B, HW, C] + B_x, N_x, C_x = x.shape + + x_copy = x + + complement = torch.cat([x, complement], 1) + + B_c, N_c, C_c = complement.shape + + # q [B, heads, HW, C//num_heads] + q = self.to_q(x).reshape(B_x, N_x, self.num_heads, C_x//self.num_heads).permute(0, 2, 1, 3) + kv = self.to_kv(complement).reshape(B_c, N_c, 2, self.num_heads, C_c//self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_x, N_x, C_x) + + x = x + x_copy + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio = 1., attn_drop=0., proj_drop=0.,drop_path = 0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super(CrossTransformerEncoderLayer, self).__init__() + self.x_norm1 = norm_layer(dim) + self.c_norm1 = norm_layer(dim) + + self.attn = MultiHeadCrossAttention(dim, num_heads, attn_drop, proj_drop) + + self.x_norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) + + self.drop1 = nn.Dropout(drop_path) + self.drop2 = nn.Dropout(drop_path) + + def forward(self, x, complement): + x = self.x_norm1(x) + complement = self.c_norm1(complement) + + x = x + self.drop1(self.attn(x, complement)) + x = x + self.drop2(self.mlp(self.x_norm2(x))) + return x + + + +class CrossTransformer(nn.Module): + def __init__(self, x_dim, c_dim, depth, num_heads, mlp_ratio =1., attn_drop=0., proj_drop=0., drop_path =0.): + super(CrossTransformer, self).__init__() + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LinearProject(x_dim, c_dim, CrossTransformerEncoderLayer(c_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)), + 
LinearProject(c_dim, x_dim, CrossTransformerEncoderLayer(x_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)) + ])) + + def forward(self, x, complement): + for x_attn_complemnt, complement_attn_x in self.layers: + x = x_attn_complemnt(x, complement=complement) + x + complement = complement_attn_x(complement, complement=x) + complement + return x, complement + + + + + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mtrans_net.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mtrans_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e80dadb45725603c7ff1da664a45533575db25 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mtrans_net.py @@ -0,0 +1,145 @@ +""" +This script implement the paper "Multi-Modal Transformer for Accelerated MR Imaging (TMI 2022)" +""" + + +import torch + +from torch import nn +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler +from .mtrans_mca import CrossTransformer + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(INPUT_DIM, HEAD_HIDDEN_DIM): + return ReconstructionHead(INPUT_DIM, HEAD_HIDDEN_DIM) + + +class CrossCMMT(nn.Module): + + def __init__(self, args): + super(CrossCMMT, self).__init__() + + INPUT_SIZE = 240 + INPUT_DIM = 1 # the channel of input + OUTPUT_DIM = 1 # the channel of output + HEAD_HIDDEN_DIM = 16 # the hidden dim of Head + TRANSFORMER_DEPTH = 4 # the depth of the transformer + TRANSFORMER_NUM_HEADS = 4 # the head's num of multi head attention + TRANSFORMER_MLP_RATIO = 3 # the MLP RATIO Of transformer + TRANSFORMER_EMBED_DIM = 256 # the EMBED DIM of transformer + P1 = 8 + P2 = 16 + CTDEPTH = 4 + + + self.head = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + self.head2 = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + + x_patch_dim = HEAD_HIDDEN_DIM * P1 ** 2 + x_num_patches = (INPUT_SIZE // P1) ** 2 + + complement_patch_dim = HEAD_HIDDEN_DIM * P2 ** 2 + complement_num_patches = (INPUT_SIZE // P2) ** 2 + + self.x_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P1, + p2=P1), + ) + + self.complement_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P2, + p2=P2), + ) + + self.x_pos_embedding = nn.Parameter(torch.randn(1, x_num_patches, x_patch_dim)) + self.complement_pos_embedding = nn.Parameter(torch.randn(1, complement_num_patches, complement_patch_dim)) + + ### + self.cross_transformer = CrossTransformer(x_patch_dim, complement_patch_dim, CTDEPTH, + TRANSFORMER_NUM_HEADS, TRANSFORMER_MLP_RATIO) + + self.p1 = P1 + self.p2 = P2 + + 
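+        # With the defaults above (INPUT_SIZE=240, HEAD_HIDDEN_DIM=16, P1=8, P2=16):
+        #   x_patch_dim          = 16 * 8**2  = 1024,  x_num_patches          = (240 // 8)**2  = 900
+        #   complement_patch_dim = 16 * 16**2 = 4096,  complement_num_patches = (240 // 16)**2 = 225
+        # so the token sequences fed to CrossTransformer have shapes (B, 900, 1024) and (B, 225, 4096).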
self.tail1 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + self.tail2 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x, complement): + x = self.head(x) + complement = self.head2(complement) + + x = x + complement + + b, _, h, w = x.shape + + _, _, c_h, c_w = complement.shape + + ### 得到每个模态的patch embedding (加上了position encoding) + x = self.x_patch_embbeding(x) + x += self.x_pos_embedding + + complement = self.complement_patch_embbeding(complement) + complement += self.complement_pos_embedding + + ### 将两个模态的patch embeddings送到cross transformer中提取特征。 + x, complement = self.cross_transformer(x, complement) + + c = int(x.shape[2] / (self.p1 * self.p1)) + H = int(h / self.p1) + W = int(w / self.p1) + + x = x.reshape(b, H, W, self.p1, self.p1, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w, ) + x = self.tail1(x) + + complement = complement.reshape(b, int(c_h/self.p2), int(c_w/self.p2), self.p2, self.p2, int(complement.shape[2]/self.p2/self.p2)) + complement = complement.permute(0, 5, 1, 3, 2, 4) + complement = complement.reshape(b, -1, c_h, c_w) + + complement = self.tail2(complement) + + return x, complement + + +def build_model(args): + return CrossCMMT(args) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mtrans_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mtrans_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..61314a6a81247b0740a1e98d078242b26ce73f90 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/mtrans_transformer.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, 
num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(args, embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=args.MODEL.TRANSFORMER_DEPTH, + num_heads=args.MODEL.TRANSFORMER_NUM_HEADS, + mlp_ratio=args.MODEL.TRANSFORMER_MLP_RATIO) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_concat_decomp.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_concat_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..07c6f64804dd586927814801b167163f0b65f58f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_concat_decomp.py @@ -0,0 +1,295 @@ +""" +Xiaohan Xing, 2023/06/25 +两个模态分别用CNN提取多个层级的特征, 每个层级的特征都经过concat之后得到fused feature, +然后每个模态都通过一层conv变换得到2C*h*w的特征, 其中C个channels作为common feature, 另C个channels作为specific features. 
+利用CDDFuse论文中的decomposition loss约束特征解耦。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
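+
+        Note:
+            At every encoder level each modality's feature is expanded to 2C channels and
+            split into C modality-common and C modality-specific channels. The returned
+            stack1_common / stack1_specific / stack2_common / stack2_specific lists are
+            meant for a CDDFuse-style decomposition loss applied outside this module
+            (see the module docstring).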
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.transform_layers1 = nn.ModuleList() + self.transform_layers2 = nn.ModuleList() + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.transform_layers1.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.transform_layers2.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + stack1_common, stack1_specific, stack2_common, stack2_specific = [], [], [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_transform_layer, net2_transform_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.transform_layers1, self.transform_layers2, \ + self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+            output1 = net1_layer(output1)
+            output2 = net2_layer(output2)
+            # print("original feature:", output1.shape)
+
+            ### extract the common and specific features of each modality
+            output1 = net1_transform_layer(output1)
+            # print("transformed feature:", output1.shape)
+            output1_common = output1[:, :(output1.shape[1]//2), :, :]
+            output1_specific = output1[:, (output1.shape[1]//2):, :, :]
+            output2 = net2_transform_layer(output2)
+            output2_common = output2[:, :(output2.shape[1]//2), :, :]
+            output2_specific = output2[:, (output2.shape[1]//2):, :, :]
+
+            stack1_common.append(output1_common)
+            stack1_specific.append(output1_specific)
+            stack2_common.append(output2_common)
+            stack2_specific.append(output2_specific)
+
+            ### merge one modality's two feature groups with the other modality's common feature,
+            ### then apply a conv so the fused feature has the same dimension as the original
+            ### feature and can be passed to the next level.
+            output1 = net1_fuse_layer(torch.cat((output1_common, output1_specific, output2_common), 1))
+            output2 = net2_fuse_layer(torch.cat((output2_common, output2_specific, output1_common), 1))
+            # print("fused feature:", output1.shape)
+
+            stack1.append(output1)
+            stack2.append(output2)
+
+            ### downsampling features
+            output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0)
+            output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0)
+
+
+        # output = torch.cat((output1, output2), 1)
+        output1 = self.conv1(output1)
+        output2 = self.conv2(output2)
+        output1 = self.decoder1(output1, stack1)
+        output2 = self.decoder2(output2, stack2)
+
+        return output1, output2, stack1_common, stack1_specific, stack2_common, stack2_specific
+        #, relation_stack1, relation_stack2 #, t1_features, t2_features
+
+
+
+def get_relation_matrix(feature):
+    """
+    Pool the feature of each level down to a 5x5 grid, then compute the relation matrix
+    between the 25 grid positions (25x25).
+    """
+    bs, c, h, w = feature.shape
+    feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15)
+    avg_pool = nn.AdaptiveAvgPool2d(5)
+    feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1)  ### (bs, 5*5, c)
+    # print("intermediate feature:", feature.shape)
+
+    feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True)  ### (bs, 5*5, 1)
+    relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1))  # (bs, 25, 25)
+
+    return relation_matrix
+
+
+
+
+
+class ConvBlock(nn.Module):
+    """
+    A Convolutional Block that consists of two convolution layers each followed by
+    instance normalization, LeakyReLU activation and dropout.
+    """
+
+    def __init__(self, in_chans: int, out_chans: int, drop_prob: float):
+        """
+        Args:
+            in_chans: Number of channels in the input.
+            out_chans: Number of channels in the output.
+            drop_prob: Dropout probability.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.drop_prob = drop_prob
+
+        self.layers = nn.Sequential(
+            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Dropout2d(drop_prob),
+            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            nn.InstanceNorm2d(out_chans),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Dropout2d(drop_prob),
+        )
+
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            image: Input 4D tensor of shape `(N, in_chans, H, W)`.
+
+        Returns:
+            Output tensor of shape `(N, out_chans, H, W)`.
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_multi_concat.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_multi_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90373e1134210809b98fdc64e3d51d76123855 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_multi_concat.py @@ -0,0 +1,306 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != 
downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + # output1, output2 = image1, image2 + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
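+            # Here the raw per-modality encoder features are appended to stack1/stack2 before
+            # fusion, so the decoders receive unfused skip connections; the concat-based fusion
+            # is applied to the pooled features below. (Note that the second fusion call uses
+            # the already-fused output1 together with output2.)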
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # relation1 = get_relation_matrix(output1) + # relation2 = get_relation_matrix(output2) + # relation_stack1.append(relation1) + # relation_stack2.append(relation2) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + # output1 = net1_layer(output1) + # output2 = net2_layer(output2) + + # ### 两个模态的特征concat + # fused_output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + # fused_output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # stack1.append(fused_output1) + # stack2.append(fused_output2) + + # ### downsampling features + # output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + # output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_multi_sum.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_multi_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4efb13b54ccbeaeb34fc51d3e72b0c50ee242 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_multi_sum.py @@ -0,0 +1,270 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + 
padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers): + + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + + ## 两个模态的特征求平均. + fused_output1 = (output1 + output2)/2.0 + fused_output2 = (output1 + output2)/2.0 + + output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features #, relation_stack1, relation_stack2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
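+    (In English: pool each feature map down to a 5x5 grid and return the 25x25 cosine-similarity
+    matrix between the grid positions; assumes H and W are divisible by 15, e.g.
+    get_relation_matrix(torch.randn(2, 32, 240, 240)) has shape (2, 25, 25).)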
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_multi_transfuse.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_multi_transfuse.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ce250a078ba0847c729c1ef1b76452f64d0f07 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_multi_transfuse.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +2023/06/23 +每个层级的特征用TransFuse layer融合之后, 将两个模态的original features和transfuse features concat, +然后在每个模态中经过conv层变换得到下一层的输入。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .TransFuse import TransFuse_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_TransFuse(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + ### conv layer to transform the features after Transfuse layer. + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + self.n_anchors = 15 + self.avgpool = nn.AdaptiveAvgPool2d((self.n_anchors, self.n_anchors)) + self.transfuse_layer1 = TransFuse_layer(n_embd=self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer2 = TransFuse_layer(n_embd=2*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer3 = TransFuse_layer(n_embd=4*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer4 = TransFuse_layer(n_embd=8*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + + self.transfuse_layers = nn.Sequential(self.transfuse_layer1, self.transfuse_layer2, self.transfuse_layer3, self.transfuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, transfuse_layer, net1_fuse_layer, net2_fuse_layer in zip(self.encoder1.down_sample_layers, \ + self.encoder2.down_sample_layers, self.transfuse_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### extract multi-level features from the encoder of each modality. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + ### Transformer-based multi-modal fusion layer + feature1 = self.avgpool(output1) + feature2 = self.avgpool(output2) + feature1, feature2 = transfuse_layer(feature1, feature2) + feature1 = F.interpolate(feature1, scale_factor=output1.shape[-1]//self.n_anchors, mode='bilinear') + feature2 = F.interpolate(feature2, scale_factor=output2.shape[-1]//self.n_anchors, mode='bilinear') + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
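+        Added note: bias is disabled in the transposed convolution because InstanceNorm2d
+        immediately follows, which would cancel any constant per-channel offset anyway.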
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_TransFuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_swinfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_swinfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6203c4331744f1da70f18e403360196a419b4cb5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/munet_swinfusion.py @@ -0,0 +1,268 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .SwinFuse_layer import SwinFusion_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_SwinFuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过SwinFusion的方式融合。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. 
+ drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.fuse_type = 'swinfuse' + self.img_size = 240 + # self.fuse_type = 'sum' + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # Swin transformer fusion modules + + self.fuse_layer1 = SwinFusion_layer(img_size=(self.img_size//2, self.img_size//2), patch_size=4, window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer2 = SwinFusion_layer(img_size=(self.img_size//4, self.img_size//4), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=2*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer3 = SwinFusion_layer(img_size=(self.img_size//8, self.img_size//8), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=4*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer4 = SwinFusion_layer(img_size=(self.img_size//16, self.img_size//16), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=8*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.swinfusion_layers = nn.Sequential(self.fuse_layer1, self.fuse_layer2, self.fuse_layer3, self.fuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, swinfuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.swinfusion_layers): + output1 = net1_layer(output1) + stack1.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # print("output of encoders:", output1.shape, output2.shape) + # print("CNN feature range:", output1.max(), output2.min()) + + ### Swin transformer-based multi-modal fusion. + feature1, feature2 = swinfuse_layer(output1, output2) + output1 = (output1 + feature1 + output2 + feature2)/4.0 + output2 = (output1 + feature1 + output2 + feature2)/4.0 + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
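+
+    Added note: both convolutions are 3x3 with padding=1 and bias=False (InstanceNorm2d
+    follows), so the block preserves the spatial size and maps
+    (N, in_chans, H, W) -> (N, out_chans, H, W).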
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_SwinFuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/original_MINet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/original_MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a115872320f677d8b34264eed95c22fff0df9c4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/original_MINet.py @@ -0,0 +1,335 @@ +from fastmri.models import common +import torch +import torch.nn as nn +import torch.nn.functional as F + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = 
self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,n_resgroups,n_resblocks,n_feats,conv=common.default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 2 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + common.Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class MINet(nn.Module): + def __init__(self, n_resgroups,n_resblocks, n_feats): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1,x2#x1=pd x2=pdfs + \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3699cc809e438be4e1bd5efaba7d96e18d7f1602 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/restormer.py @@ -0,0 +1,307 @@ 
+## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + # print("feature before FFN projection:", x.max(), x.min()) + x = self.project_out(x) + # print("FFN output:", x.max(), x.min()) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def 
forward(self, x): + b,c,h,w = x.shape + # print("Attention block input:", x.max(), x.min()) + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + # print("attention values:", attn) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("Attention block output:", out.max(), out.min()) + + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + # dim = 32, + # num_blocks = [4,4,4,4], + dim = 32, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + + # stack = [] + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + # print("encoder block1 feature:", out_enc_level1.shape, out_enc_level1.max(), out_enc_level1.min()) + # stack.append(out_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + # print("encoder block2 feature:", out_enc_level2.shape, out_enc_level2.max(), out_enc_level2.min()) + # stack.append(out_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + # print("encoder block3 feature:", out_enc_level3.shape, out_enc_level3.max(), out_enc_level3.min()) + # 
stack.append(out_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + # print("encoder block4 feature:", latent.shape, latent.max(), latent.min()) + # stack.append(latent) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + + return out_dec_level1 #, stack + + + +def build_model(args): + return Restormer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/restormer_block.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/restormer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..170ae6a826e5a3e356cb48f8c89500c2d5242816 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/restormer_block.py @@ -0,0 +1,321 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + 
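+# Hedged example (added for illustration; the helper below is ours and not part of the
+# original file): both LayerNorm variants above normalize over the channel axis after
+# flattening (B, C, H, W) -> (B, H*W, C) via to_3d/to_4d. With the default weight=1 and
+# bias=0, the per-pixel channel mean of the WithBias output is therefore close to 0.
+def _layernorm_sanity_check(channels: int = 8, size: int = 16) -> torch.Tensor:
+    x = torch.randn(2, channels, size, size)
+    y = to_4d(WithBias_LayerNorm(channels)(to_3d(x)), size, size)
+    return y.mean(dim=1).abs().max()  # expected to be close to 0
+
+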
+class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + # print("input to the LayerNorm layer:", x.max(), x.min()) + h, w = x.shape[-2:] + x = to_4d(self.body(to_3d(x)), h, w) + # print("output of the LayerNorm:", x.max(), x.min()) + return x + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature = 1.0 / ((dim//num_heads) ** 0.5) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + # print("input to the Attention layer:", x.max(), x.min()) + + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + # print("q range:", q.max(), q.min()) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + # print("scaling parameter:", self.temperature) + # print("attention matrix before softmax:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + attn = attn.softmax(dim=-1) + # print("attention matrix range:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + # print("v range:", v.max(), v.min(), v.mean()) + + out = (attn @ v) + # print("self-attention output:", out.max(), out.min()) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("output of attention layer:", out.max(), out.min()) + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, pre_norm): + super(TransformerBlock, self).__init__() + + self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + self.pre_norm = pre_norm + + def 
forward(self, x): + x = self.conv(x) + # print("restormer block input:", x.max(), x.min()) + if self.pre_norm: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + else: + x = self.norm1(x + self.attn(x)) + x = self.norm2(x + self.ffn(x)) + # x = self.norm1(self.attn(x)) + # x = self.norm2(self.ffn(x)) + # x = x + self.attn(x) + # x = x + self.ffn(x) + # print("restormer block output:", x.max(), x.min()) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim = 48, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + 
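+        # Added note: this repeats the level-3 decoder pattern. PixelShuffle-based
+        # upsampling halves the channels, concatenating the matching encoder feature
+        # doubles them again, and the 1x1 reduce_chan_level2 conv brings them back to
+        # int(dim * 2**1) before the level-2 decoder blocks.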
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + + + return out_dec_level1 + + + + +if __name__ == '__main__': + model = Restormer() + # print(model) + + A = torch.randn((1, 1, 240, 240)) + x = model(A) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/swinIR.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/swinIR.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfc9c175c02531ac80a05ba4abfc0776f752dd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/swinIR.py @@ -0,0 +1,874 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * 
(self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
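A hypothetical smoke test for the shifted-window block above (dim, resolution, and head count are illustrative). With shift_size > 0 the SW-MSA mask is precomputed from input_resolution at construction time, and the forward pass preserves the token shape:

import torch

blk = SwinTransformerBlock(dim=32, input_resolution=(16, 16), num_heads=4,
                           window_size=8, shift_size=4)
print(blk.attn_mask.shape)         # torch.Size([4, 64, 64]) = (nW, ws*ws, ws*ws)

x = torch.randn(2, 16 * 16, 32)    # (B, H*W, C) token layout expected by forward
y = blk(x, x_size=(16, 16))
print(y.shape)                     # torch.Size([2, 256, 32])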
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
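A quick shape check (illustrative sizes) for PatchMerging above: each 2x2 neighbourhood of tokens is grouped, so H*W tokens with C channels become (H/2)*(W/2) tokens with 2C channels. The SwinIR model further below builds its RSTB layers with downsample=None, so this layer is defined but not used for restoration:

import torch

merge = PatchMerging(input_resolution=(16, 16), dim=32)
x = torch.randn(2, 16 * 16, 32)    # (B, H*W, C)
y = merge(x)
print(y.shape)                     # torch.Size([2, 64, 64]) = (B, H/2 * W/2, 2*C)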
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 4, 4, 6], + embed_dim=32, num_heads=[6, 6, 6, 6], mlp_ratio=4, upsampler='pixelshuffledirect') + + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/swin_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..877bdfbffc91012e4ac31f41b838ebb116f4a825 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/swin_transformer.py @@ -0,0 +1,878 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
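A usage note on the SwinIR class above: embed_dim must be divisible by every entry of num_heads, otherwise the qkv reshape in WindowAttention fails, so the embed_dim=32 with num_heads=[6, 6, 6, 6] combination in build_model above appears inconsistent. Below is a hypothetical single-channel, upscale=1 smoke test with a divisible pair (sizes are illustrative, chosen to mirror the 240x240 setting used elsewhere in this file):

import torch

model = SwinIR(upscale=1, in_chans=1, img_size=(240, 240), window_size=8,
               img_range=1., depths=[2, 2, 2], embed_dim=60,
               num_heads=[6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')

x = torch.randn(1, 1, 240, 240)
print(model(x).shape)    # torch.Size([1, 1, 240, 240])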
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=1, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + # print("shallow feature:", x.shape) + x = self.conv_after_body(self.forward_features(x)) + x + # print("deep feature:", x.shape) + x = self.upsample(x) + # print("after upsample:", x.shape) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x 
= self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + # return SwinIR(upscale=1, img_size=(240, 240), + # window_size=8, img_range=1., depths=[6, 6, 6, 6], + # embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + +if __name__ == '__main__': + height = 240 + width = 240 + window_size = 8 + model = SwinIR(upscale=1, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + x = torch.randn((1, 1, height, width)) + print("input:", x.shape) + x = model(x) + print("output:", x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet.zip b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet.zip new file mode 100644 index 0000000000000000000000000000000000000000..e7b71c1287551fd6119efd7c62da2196307998f5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:369de2a28036532c446409ee1f7b953d5763641ffd452675613e1bb9bba89eae +size 19354 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet/trans_unet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet/trans_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa61dbadfcaee9b26594307adc90cdd1d2a84e3e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet/trans_unet.py @@ -0,0 +1,489 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math + +from os.path import join as pjoin + +import torch +import torch.nn as nn +import numpy as np + +from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair +from scipy import ndimage +from . 
import vit_seg_configs as configs +# import .vit_seg_configs as configs +from .vit_seg_modeling_resnet_skip import ResNetV2 + + +logger = logging.getLogger(__name__) + + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class 
Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + # print("image size:", img_size, "grid size:", grid_size) + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + # print("patch_size_real", patch_size_real) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + # print("input x:", x.shape) ### (4, 3, 240, 240) + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + # print("output of the hybrid model:", x.shape) ### (4, 1024, 15, 15) + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = 
np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + # print(self.encoder) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) ### "features" store the features from different layers. + # print("embedding_output:", embedding_output.shape) ## (B, n_patch, hidden) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + # print("encoded feature:", encoded.shape) + return encoded, attn_weights, features + + # return embedding_output, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class ReconstructionHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + 
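+ # conv_more maps the reshaped transformer output (B, hidden_size, h, w) to head_channels
+ # CNN feature maps; the DecoderBlocks built below then upsample x2 per block through
+ # decoder_channels=(256, 128, 64, 16), concatenating the first n_skip ResNet skip features
+ # (skip_channels) at the matching resolutions.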
self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + # print(self.blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): + super(VisionTransformer, self).__init__() + self.num_classes = num_classes + self.zero_head = zero_head + self.classifier = config.classifier + self.transformer = Transformer(config, img_size, vis) + self.decoder = DecoderCup(config) + self.segmentation_head = ReconstructionHead( + in_channels=config['decoder_channels'][-1], + out_channels=config['n_classes'], + kernel_size=3, + ) + self.activation = nn.Tanh() + self.config = config + + def forward(self, x): + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + ### hybrid CNN and transformer to extract features from the input images. + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + + # ### extract features with CNN only. 
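+ # (the commented-out call below is that CNN-only variant, kept for reference)
+ # `features` holds the ResNetV2 skip feature maps, coarsest first; DecoderCup upsamples
+ # the encoded tokens and fuses the first n_skip of these maps, and the reconstruction
+ # head followed by Tanh produces the final image (1 channel with the R50-ViT-B_16 config
+ # used by build_model).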
+ # x, features = self.transformer(x) # (B, n_patch, hidden) + + x = self.decoder(x, features) + # logits = self.segmentation_head(x) + logits = self.activation(self.segmentation_head(x)) + # print("logits:", logits.shape) + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + + + +def build_model(args): + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + print(net) + # net.load_from(weights=np.load(config_vit.pretrained_path)) + return net + + +if __name__ == "__main__": + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + net.load_from(weights=np.load(config_vit.pretrained_path)) + + input_data = torch.rand(4, 1, 240, 240).cuda() + print("input:", input_data.shape) + output = net(input_data) + print("output:", output.shape) diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bed7d3346e30328a38ac6fef1621f788d905df1e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py @@ -0,0 +1,132 @@ +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 4 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + # config.patches.grid = (16, 16) + config.patches.grid = (15, 15) ### for the input image with size (240, 240) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'recon' + config.pretrained_path = '/home/xiaohan/workspace/MSL_MRI/code/pretrained_model/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 1 + config.n_skip = 3 + config.activation = 'tanh' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. 
customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae80ef7d96c46cb91bda849901aef34008e62ed --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py @@ -0,0 +1,160 @@ +import math + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. 
+ self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - 
x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/transformer_modules.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/transformer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..900042884305619e39ad9b792b3b1936dfbd37b3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/transformer_modules.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # 
[B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=4, + num_heads=4, + mlp_ratio=3) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dd935cd34a121d8079a8f5184d4ea29e15c8da --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unet.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F + + +class Unet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. 
In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = image + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + # print("unet feature range:", output.max(), output.min()) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unet_restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f78aa5f192957b2706c6f691dfc66c427b65ae --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unet_restormer.py @@ -0,0 +1,234 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + # print("Unet feature range:", output.max(), output.min()) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unet_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..982fe23cec321de6b636e04c51b62037054ea432 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unet_transformer.py @@ -0,0 +1,346 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +对于Unet中的high-level feature, 用Transformer进行特征交互,然后和Transformer之前的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): 
+ for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Unet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output = image + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack, output = self.encoder(output) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output) # [4, 225, 256], 225 is the number of patches. + patch_embed_input = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1])), int(np.sqrt(feature_output.shape[1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[-1], h, w) # [4,512*2,24,24] + + output = torch.cat((output, feature_output), 1) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output = self.decoder(output, stack) + + return output + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. 
+ + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Transformer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unimodal_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unimodal_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c99214e10bee2f7a1be49f8560799ffbc14ed0a5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/compare_models/unimodal_transformer.py @@ -0,0 +1,112 @@ +import torch + +from torch import nn +from .transformer_modules import build_transformer +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(): + return ReconstructionHead(input_dim=1, hidden_dim=12) + + + +class CMMT(nn.Module): + + def __init__(self, args): + super(CMMT, self).__init__() + + self.head = build_head() + + HEAD_HIDDEN_DIM = 12 + PATCH_SIZE = 16 + INPUT_SIZE = 240 + OUTPUT_DIM = 1 + + patch_dim = HEAD_HIDDEN_DIM* PATCH_SIZE ** 2 + num_patches = (INPUT_SIZE // PATCH_SIZE) ** 2 + + self.transformer = build_transformer(patch_dim) + + self.patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=PATCH_SIZE, p2=PATCH_SIZE), + ) + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, patch_dim)) + + + self.p1 = PATCH_SIZE + self.p2 = PATCH_SIZE + + self.tail = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x): + + b,_ , h, w = x.shape + + x = self.head(x) + + x= self.patch_embbeding(x) + + x += self.pos_embedding + + x = self.transformer(x) # b HW p1p2c + + c = int(x.shape[2]/(self.p1*self.p2)) + H = 
int(h/self.p1) + W = int(w/self.p2) + + x = x.reshape(b, H, W, self.p1, self.p2, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w,) + x = self.tail(x) + + return x + + +def build_model(args): + return CMMT(args) + + + + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/modules.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/modules.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
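# The custom InstanceNorm below normalizes each (sample, channel) slice by its
# own spatial mean and variance.  A minimal numerical sketch (hypothetical
# sizes) showing that the affine-free part matches PyTorch's functional
# instance norm:
import torch
import torch.nn.functional as F
x = torch.randn(2, 3, 8, 8)                               # (N, C, H, W)
flat = x.view(2, 3, -1)
mean = flat.mean(dim=2, keepdim=True)
inv_std = torch.rsqrt(((flat - mean) ** 2).mean(dim=2, keepdim=True) + 1e-5)
norm = ((flat - mean) * inv_std).view_as(x)
print(torch.allclose(norm, F.instance_norm(x, eps=1e-5), atol=1e-5))   # True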
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/mynet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/mynet.py new file mode 100644 index 0000000000000000000000000000000000000000..91c2966b09a9261e23582c29093b9e59ebd0d4be --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks/mynet.py @@ -0,0 +1,389 @@ +import torch +from torch import nn +from networks import common_freq as common + + +class TwoBranch(nn.Module): + def __init__(self, args): + super(TwoBranch, self).__init__() + + num_group = 4 + num_every_group = args.base_num_every_group + + self.args = args + + self.init_T2_frq_branch(args) + self.init_T2_spa_branch(args, num_every_group) + self.init_T2_fre_spa_fusion(args) + + self.init_T1_frq_branch(args) + self.init_T1_spa_branch(args, num_every_group) + + self.init_modality_fre_fusion(args) + self.init_modality_spa_fusion(args) + + + def init_T2_frq_branch(self, args): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_fre = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + 
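# Shape sketch for the frequency branch being assembled here (sizes are an
# assumption: a 128x128 single-channel slice and C = args.num_features = 32).
# Each common.DownSample halves H and W while keeping C channels, each
# common.UpSampler(2, C) doubles them, and FreBlock9 preserves shape, so the
# per-level skip additions in forward() (128 -> 64 -> 32 -> 16 and back up)
# line up level by level.
_C = 32
_x = torch.randn(1, _C, 128, 128)
_down = common.DownSample(_C, False, False)
_up = common.UpSampler(2, _C)
print(tuple(_down(_x).shape))        # (1, 32, 64, 64)
print(tuple(_up(_down(_x)).shape))   # (1, 32, 128, 128)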
modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_neck_fre = [common.FreBlock9(args.num_features, args) + ] + self.neck_fre = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up1_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up1_fre = nn.Sequential(*modules_up1_fre) + self.up1_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up2_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up2_fre = nn.Sequential(*modules_up2_fre) + self.up2_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up3_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up3_fre = nn.Sequential(*modules_up3_fre) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + # define tail module + modules_tail_fre = [ + common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act)] + self.tail_fre = nn.Sequential(*modules_tail_fre) + + def init_T2_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1 = nn.Sequential(*modules_down1) + + + self.down1_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2 = nn.Sequential(*modules_down2) + + self.down2_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down3 = nn.Sequential(*modules_down3) + self.down3_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.neck = nn.Sequential(*modules_neck) + + self.neck_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_up1 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up1 = nn.Sequential(*modules_up1) + + self.up1_mo = nn.Sequential(common.ResidualGroup( + 
args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_up2 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up2 = nn.Sequential(*modules_up2) + self.up2_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + + modules_up3 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up3 = nn.Sequential(*modules_up3) + self.up3_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + # define tail module + modules_tail = [ + common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act)] + + self.tail = nn.Sequential(*modules_tail) + + def init_T2_fre_spa_fusion(self, args): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(args.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, args): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_fre_T1 = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre_T1 = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre_T1 = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre_T1 = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_neck_fre = [common.FreBlock9(args.num_features, args) + ] + self.neck_fre_T1 = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + def init_T1_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_T1 = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = nn.Sequential(*modules_down1) + + + self.down1_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = nn.Sequential(*modules_down2) + + self.down2_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, 
n_resblocks=num_every_group, norm=None) + ] + self.down3_T1 = nn.Sequential(*modules_down3) + self.down3_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.neck_T1 = nn.Sequential(*modules_neck) + + self.neck_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + + def init_modality_fre_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + def forward(self, main, aux): + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo) # 64 + up2_fre_mo = self.up2_fre_mo(up2_fre) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse) + down1_fuse_mo = self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = 
self.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.up3(up2_fuse_mo) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + # import matplotlib.pyplot as plt + # plt.axis('off') + # plt.imshow((255*up3_fre_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fre_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_fuse_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fuse_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + # breakpoint() + + res = self.tail(up3_fuse_mo) + + return {'img_out': res + main, 'img_fre': res_fre + main} + +def make_model(args): + return TwoBranch(args) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__pycache__/__init__.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c15284be57bcc6b4485d437ab4679123e8008c6 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__pycache__/common_freq.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__pycache__/common_freq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64ad56ad674c4607224268d6e1d804890482e407 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__pycache__/common_freq.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__pycache__/mynet.cpython-310.pyc 
b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__pycache__/mynet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08dc6e713b2007c43524d7b32a4837388f6c7ff6 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/__pycache__/mynet.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/common_freq.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..559392e6252c5be1e8b94d4d3895771450160d67 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/common_freq.py @@ -0,0 +1,411 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
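# A quick check of the relationship used throughout this file (hypothetical
# sizes): invPixelShuffle(2) above is the exact inverse of nn.PixelShuffle(2),
# trading spatial resolution for channels (H, W -> H/2, W/2 and C -> 4C) with
# no loss of information.
_x = torch.randn(1, 8, 16, 16)
_packed = invPixelShuffle(2)(_x)              # (1, 32, 8, 8)
_restored = nn.PixelShuffle(2)(_packed)       # (1, 8, 16, 16)
print(torch.equal(_restored, _x))             # True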
+ for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None, temp_channel=None): + super(ConvBNReLU2D, self).__init__() + + if not isinstance(temp_channel, type(None)): + self.temb_proj = torch.nn.Linear(temp_channel, out_channels) + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs, temb=None): + + out = self.layers(inputs) + + if not isinstance(temb, type(None)): + out = out + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features, temp_channel=None): + super(ResBlock, self).__init__() + self.layers1 = ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1, temp_channel=temp_channel) + self.layers2 = ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + + + def forward(self, inputs, temp=None): + x = self.layers1(inputs, temp) + x = self.layers2(x) + + return F.relu(x + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): 
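# FreBlock9 (defined just below) filters features in the frequency domain by
# processing the amplitude and phase of a 2-D real FFT separately.  A minimal
# sketch (hypothetical sizes) of that decomposition and its inverse; leaving
# amplitude and phase untouched reconstructs the input:
_x = torch.randn(1, 8, 32, 32)
_X = torch.fft.rfft2(_x, norm='backward')
_amp, _pha = torch.abs(_X), torch.angle(_X)
_rec = torch.fft.irfft2(torch.complex(_amp * torch.cos(_pha), _amp * torch.sin(_pha)),
                        s=_x.shape[-2:], norm='backward')
print(torch.allclose(_rec, _x, atol=1e-5))    # True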
+ return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks, temp_channel=None): + super(ResidualGroup, self).__init__() + + self.head = ResBlock(n_feat, temp_channel) # Use to be two + + modules_body = [ResBlock(n_feat) for _ in range(n_resblocks - 1)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x, t=None): + x = self.head(x, t) + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels, args): + super(FreBlock9, self).__init__() + + self.temb_proj = torch.nn.Linear(args.temb_channels, channels) + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x, temb=None): + # print("x: ", x.shape) + _, _, H, W = x.shape + + if not isinstance(temb, type(None)): + x = x + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa 
= nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = 
fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ARTUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae23fed16b0d9454a5ea09f17b4322f1e9d7e928 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ARTUnet.py @@ -0,0 +1,751 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class 
TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, 
I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + # print("x shape:", x.shape) + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +########################################################################## +class CS_TransformerBlock(nn.Module): + r""" 在ART Transformer Block的基础上添加了channel attention的分支。 + 参考论文: Dual Aggregation Transformer for Image Super-Resolution + 及其代码: https://github.com/zhengchen1999/DAT/blob/main/basicsr/archs/dat_arch.py + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.InstanceNorm2d(dim), + nn.GELU()) + + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + # nn.InstanceNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1)) + + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.InstanceNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # print("x before dwconv:", x.shape) + # convolution output + conv_x = self.dwconv(x.permute(0, 3, 1, 2)) + # conv_x = x.permute(0, 3, 1, 2) + # print("x after dwconv:", x.shape) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + attened_x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + attened_x = attened_x[:, :H, :W, :].contiguous() + + attened_x = attened_x.view(B, H * W, C) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C) + # S-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W) + spatial_map = self.spatial_interaction(attention_reshape) + + # C-I + attened_x = attened_x * torch.sigmoid(channel_map) + # S-I + conv_x = torch.sigmoid(spatial_map) * conv_x + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + + x = attened_x + conv_x + + # 
FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ARTUnet_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ARTUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d210d46e2a4d73781b3b16b63a53fdc0f25b9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ARTUnet_new.py @@ -0,0 +1,823 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +The image restoration network is built as a UNet; in each stage, the transformer blocks alternate between dense self-attention and sparse self-attention. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True, visualize_attention_maps=False): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + self.visualize_attention_maps = visualize_attention_maps + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + # assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + # qkv: 3, B_, num_heads, N, C // num_heads + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + # B_, num_heads, N, C // num_heads * + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + attn_map = attn.detach().cpu().numpy().mean(axis=1) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) # (B * nP, nHead, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + 
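+ # At this point `attn` holds one softmax-normalised attention map per head and per group, shape (B*nP, num_heads, N, N) with N = Gh*Gw; padded key positions (if any) were filled with NEG_INF before the softmax, so they receive near-zero weight.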
# attn: B * nP, num_heads, N, N + # v: B * nP, num_heads, N, C // num_heads + # -attn-@-v-> B * nP, num_heads, N, C // num_heads + # -transpose-> B * nP, N, num_heads, C // num_heads + # -reshape-> B * nP, N, C + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, G ** 2, G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, Gh * Gw, Gh * Gw) + else: + x = self.attn(x, Gh, Gw, mask=attn_mask) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +class ConcatTransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. 
+ num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + ################## + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // 
self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + ################## + + x = torch.cat((x, y), dim=1) + + shortcut = x + x = self.norm1(x) + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, 2 * Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, 2 * G ** 2, 2 * G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, 2 * Gh * Gw, 2 * Gh * Gw) + else: + x = self.attn(x, 2 * Gh, Gw, mask=attn_mask) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + output = x + + # print(x.shape, Gh, Gw) + x = output[:, :Gh * Gw, :] + y = output[:, Gh * Gw:, :] + assert x.shape == y.shape + + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = y.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = y.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, y, attn_map + else: + return x, y + + def get_patches(self, x, x_size): + H, W = x_size + B, H, W, C = x.shape + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, 
Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + return x, attn_mask, (Hd, Wd), (Gh, Gw), (pad_r, pad_b), G, I + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ART_Restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ART_Restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c23123f0ac1a8e82e2bf907658ce8d91951c3e4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ART_Restormer.py @@ -0,0 +1,592 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure; each block is an ART block and a Restormer block connected in series. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
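+ # Broadcasting the (nP, 1, Gh*Gw) padding mask over this (nP, Gh*Gw, Gh*Gw) zero tensor marks every padded key column with NEG_INF, so the softmax inside Attention assigns padded pixels near-zero weight; the padded query rows are discarded when the padding is cropped off further below.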
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[4, 4, 4, 4], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + + self.down1_2 = Downsample(dim) ## 
From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + 
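+ # Channel bookkeeping for level 1: up2_1 maps the 2*dim level-2 features to dim channels while doubling the spatial resolution (PixelShuffle), and the skip concatenation with out_enc_level1 restores 2*dim, so the level-1 ART/Restormer decoder blocks, the refinement blocks and self.output all run at dim * 2 ** 1 without a 1x1 reduction conv.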
self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + + ######################################### Encoder level1 ############################################ + ### ART encoder block + for layer in self.ART_encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + + ### Restormer encoder block + out_enc_level1 = rearrange(out_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = self.Restormer_encoder_level1(out_enc_level1) + # stack.append(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + ######################################### Encoder level2 ############################################ + ### ART encoder block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.ART_encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + out_enc_level2 = rearrange(out_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = self.Restormer_encoder_level2(out_enc_level2) + # stack.append(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.ART_encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + out_enc_level3 = rearrange(out_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = self.Restormer_encoder_level3(out_enc_level3) + # stack.append(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 
############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.ART_latent: + latent = layer(latent, [H // 8, W // 8]) + + ### Restormer encoder block + latent = rearrange(latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = self.Restormer_latent(latent) + # stack.append(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + out_dec_level3 = inp_dec_level3 + for layer in self.ART_decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + out_dec_level3 = rearrange(out_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + out_dec_level3 = self.Restormer_decoder_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + out_dec_level2 = inp_dec_level2 + for layer in self.ART_decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + out_dec_level2 = rearrange(out_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + out_dec_level2 = self.Restormer_decoder_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + out_dec_level1 = inp_dec_level1 + for layer in self.ART_decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + ### Restormer decoder block + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_decoder_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_refinement(out_dec_level1) + + out_dec_level1 = 
self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ART_Restormer_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ART_Restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..de8f921e81d79cb7af14a75326f0ddaaf3b3f4c3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ART_Restormer_v2.py @@ -0,0 +1,663 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure; every block consists of an ART branch and a Restormer branch connected in parallel. +The input features are extracted by the ART and Restormer branches separately; the two feature maps are concatenated and a conv layer then produces the output features of the block. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
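# padded positions receive NEG_INF so the softmax assigns them (near) zero attention weight +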
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.enc_fuse_level1 = nn.Conv2d(int(dim * 
2 ** 1), dim, kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.enc_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.enc_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + self.enc_fuse_level4 = nn.Conv2d(int(dim * 2 ** 4), int(dim * 2 ** 3), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.dec_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], 
window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.dec_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.dec_fuse_level1 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + # self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + # bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + # self.fuse_refinement = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + # def forward(self, inp_img, aux_img): + # stack = [] + # bs, _, H, W = inp_img.shape + + # fuse_img = torch.cat((inp_img, aux_img), 1) + # inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + + def forward(self, inp_img): + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + + ######################################### Encoder level1 ############################################ + art_enc_level1 = inp_enc_level1 + restor_enc_level1 = inp_enc_level1 + + ### ART encoder block + for layer in self.ART_encoder_level1: + art_enc_level1 = layer(art_enc_level1, [H, W]) + + ### Restormer encoder block + restor_enc_level1 = rearrange(restor_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_enc_level1 = self.Restormer_encoder_level1(restor_enc_level1) ### (b, c, h, w) + + art_enc_level1 = rearrange(art_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = torch.cat((art_enc_level1, restor_enc_level1), 1) + out_enc_level1 = self.enc_fuse_level1(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + + ######################################### Encoder level2 ############################################ + ### ART encoder 
block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + art_enc_level2 = inp_enc_level2 + restor_enc_level2 = inp_enc_level2 + + for layer in self.ART_encoder_level2: + art_enc_level2 = layer(art_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + restor_enc_level2 = rearrange(restor_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + restor_enc_level2 = self.Restormer_encoder_level2(restor_enc_level2) + + art_enc_level2 = rearrange(art_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = torch.cat((art_enc_level2, restor_enc_level2), 1) + out_enc_level2 = self.enc_fuse_level2(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + art_enc_level3 = inp_enc_level3 + restor_enc_level3 = inp_enc_level3 + + for layer in self.ART_encoder_level3: + art_enc_level3 = layer(art_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + restor_enc_level3 = rearrange(restor_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + restor_enc_level3 = self.Restormer_encoder_level3(restor_enc_level3) + + art_enc_level3 = rearrange(art_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = torch.cat((art_enc_level3, restor_enc_level3), 1) + out_enc_level3 = self.enc_fuse_level3(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 ############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + art_latent = inp_enc_level4 + restor_latent = inp_enc_level4 + + for layer in self.ART_latent: + art_latent = layer(art_latent, [H // 8, W // 8]) + + ### Restormer encoder block + restor_latent = rearrange(restor_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + restor_latent = self.Restormer_latent(restor_latent) + + art_latent = rearrange(art_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = torch.cat((art_latent, restor_latent), 1) + latent = self.enc_fuse_level4(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + art_dec_level3 = inp_dec_level3 + restor_dec_level3 = inp_dec_level3 + + for layer in self.ART_decoder_level3: + art_dec_level3 = layer(art_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + restor_dec_level3 = rearrange(restor_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + restor_dec_level3 = 
self.Restormer_decoder_level3(restor_dec_level3) + + art_dec_level3 = rearrange(art_dec_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_dec_level3 = torch.cat((art_dec_level3, restor_dec_level3), 1) + out_dec_level3 = self.dec_fuse_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + art_dec_level2 = inp_dec_level2 + restor_dec_level2 = inp_dec_level2 + + for layer in self.ART_decoder_level2: + art_dec_level2 = layer(art_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + restor_dec_level2 = rearrange(restor_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + restor_dec_level2 = self.Restormer_decoder_level2(restor_dec_level2) + + art_dec_level2 = rearrange(art_dec_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_dec_level2 = torch.cat((art_dec_level2, restor_dec_level2), 1) + out_dec_level2 = self.dec_fuse_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + art_dec_level1 = inp_dec_level1 + restor_dec_level1 = inp_dec_level1 + + for layer in self.ART_decoder_level1: + art_dec_level1 = layer(art_dec_level1, [H, W]) + + ### Restormer decoder block + restor_dec_level1 = rearrange(restor_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_dec_level1 = self.Restormer_decoder_level1(restor_dec_level1) + + art_dec_level1 = rearrange(art_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = torch.cat((art_dec_level1, restor_dec_level1), 1) + out_dec_level1 = self.dec_fuse_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +# def build_model(args): +# return ART_Restormer( +# inp_channels=2, +# out_channels=1, +# dim=32, +# num_art_blocks=[2, 4, 4, 6], +# num_restormer_blocks=[2, 2, 2, 2], +# num_refinement_blocks=4, +# heads=[1, 2, 4, 8], +# window_size=[10, 10, 10, 10], mlp_ratio=4., +# qkv_bias=True, qk_scale=None, +# interval=[24, 12, 6, 3], +# ffn_expansion_factor = 2.66, +# LayerNorm_type = 'WithBias', ## Other option 'BiasFree' +# bias=False) + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 
2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ARTfuse_layer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ARTfuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..78af0b27528b50db897936f6b8b4eee51d566b1c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/ARTfuse_layer.py @@ -0,0 +1,705 @@ +""" +ART layer in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input to the Attention layer:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + # print("q range:", q.max(), q.min()) + # print("k range:", k.max(), k.min()) + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + # print("attention matrix:", attn.shape, attn) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + # print("feature before proj-layer:", x.max(), x.min()) + x = self.proj(x) + # print("feature after proj-layer:", x.max(), x.min()) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. 
+ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pre_norm=True): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + # self.conv = nn.Conv2d(dim, dim, kernel_size=1) + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + # self.conv = ConvBlock(self.dim, self.dim, drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.bn_norm = nn.BatchNorm2d(dim) + self.pre_norm = pre_norm + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + # x = self.conv(x) + shortcut = x + if self.pre_norm: + # print("using pre_norm") + x = self.norm1(x) ## the feature range becomes very small after normalization + x = x.view(B, H, W, C) + # print("normalized input range:", x.max(), x.min(), x.mean()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 
1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + # print("range of feature before layer:", x.max(), x.min(), x.mean()) + # x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + # x_bn = self.bn_norm(x.permute(0, 3, 1, 2)) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + # print("[ART layer] input and output feature difference:", torch.mean(torch.abs(shortcut - x))) + # print("range of feature before ART layer:", shortcut.max(), shortcut.min(), shortcut.mean()) + # print("range of feature after ART layer:", x.max(), x.min(), x.mean()) + if self.pre_norm: + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + # print("using post_norm") + x = self.norm1(shortcut + self.drop_path(x)) + x = self.norm2(x + self.drop_path(self.mlp(x))) + # x = shortcut + self.drop_path(x) + # x = x + self.drop_path(self.mlp(x)) + + # print("range of fused feature:", x.max(), x.min()) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + + + +########################################################################## +class Cross_TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +class Cross_TransformerBlock_v2(nn.Module): + r""" ART Transformer Block. + The features of the two modalities are concatenated along the channel dimension and windowed together. The patches of the two modalities that fall into the same window are then merged and fed to the transformer. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + xy = torch.cat((x, y), -1) + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + xy = F.pad(xy, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + xy = xy.reshape(B, Hd // G, G, Wd // G, G, 2*C).permute(0, 1, 3, 2, 4, 5).contiguous() + xy = xy.reshape(B * Hd * Wd // G ** 2, G ** 2, 2*C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + xy = xy.reshape(B, Gh, I, Gw, I, 2*C).permute(0, 2, 4, 1, 3, 5).contiguous() + xy = xy.reshape(B * I * I, Gh * Gw, 2*C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ x = xy[:, :, :C] + y = xy[:, :, C:] + xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DataConsistency.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DataConsistency.py new file mode 100644 index 0000000000000000000000000000000000000000..5f460e9acabcb33e1a3f812eb9acaeb337da446c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DataConsistency.py @@ -0,0 +1,48 @@ +""" +Created: DataConsistency @ Xiyang Cai, 2023/09/09 + +Data consistency layer for k-space signal. 
+
+Ref: DataConsistency in DuDoRNet (https://github.com/bbbbbbzhou/DuDoRNet)
+
+"""
+
+import torch
+from torch import nn
+from einops import repeat
+
+
+def data_consistency(k, k0, mask):
+    """
+    k    - input in k-space
+    k0   - initially sampled elements in k-space
+    mask - corresponding nonzero location
+    """
+    out = (1 - mask) * k + mask * k0
+    return out
+
+
+class DataConsistency(nn.Module):
+    """
+    Create data consistency operator
+    """
+
+    def __init__(self):
+        super(DataConsistency, self).__init__()
+
+    def forward(self, k, k0, mask):
+        """
+        k    - input in frequency domain, of shape (n, nx, ny, 2)
+        k0   - initially sampled elements in k-space
+        mask - corresponding nonzero location (n, 1, len, 1)
+        """
+
+        if k.dim() != 4:  # expect a 4D input (n, nx, ny, 2)
+            raise ValueError("error in data consistency layer!")
+
+        # mask = repeat(mask.squeeze(1, 3), 'b x -> b x y c', y=k.shape[1], c=2)
+        mask = torch.tile(mask, (1, mask.shape[2], 1, k.shape[-1]))  ### [n, 320, 320, 2]
+        # print("k and k0 shape:", k.shape, k0.shape)
+        out = data_consistency(k, k0, mask)
+
+        return out, mask
diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DuDo_mUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DuDo_mUnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a56069a27f0cdd392482b41dc4e3b79252295487
--- /dev/null
+++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DuDo_mUnet.py
@@ -0,0 +1,359 @@
+"""
+Xiaohan Xing, 2023/10/25
+Take the k-space data of the two modalities and reconstruct the target k-space with a k-space network.
+The reconstructed k-space is then transformed back to the image domain and fed to an image-domain network for image reconstruction.
+"""
+# coding: utf-8
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .DataConsistency import DataConsistency
+from dataloaders.BRATS_kspace_dataloader import ifft2c
+
+from .Kspace_mUnet import kspace_mmUnet
+
+
+class mmConvKSpace(nn.Module):
+    """
+    Interpolate the kspace data of the target modality with several conv layers.
+    """
+    def __init__(
+            self,
+            chans: int = 32,
+            drop_prob: float = 0.0):
+
+        super().__init__()
+
+        self.in_chans = 4
+        self.out_chans = 2
+        self.chans = chans
+        self.drop_prob = drop_prob
+
+        self.small_conv = kspace_ConvBlock(self.in_chans, self.chans, 3, self.drop_prob)
+        self.mid_conv = kspace_ConvBlock(self.in_chans, self.chans, 5, self.drop_prob)
+        self.large_conv = kspace_ConvBlock(self.in_chans, self.chans, 7, self.drop_prob)
+
+        self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False),
+                                         nn.ReLU(),
+                                         nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False),
+                                         nn.Sigmoid())
+
+        self.final_conv = kspace_ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob)
+
+        self.dcs = DataConsistency()
+
+
+    def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            kspace: Input 4D tensor of shape `(N, 2, H, W)` (real/imag channels of the target modality).
+            ref_kspace: Reference-modality k-space of shape `(N, 2, H, W)`.
+            mask: Down-sample mask `(N, 1, len, 1)`
+
+        Returns:
+            Output tensor of shape `(N, out_chans, H, W)`.
+        """
+        k_in = torch.cat((kspace, ref_kspace), 1)
+        # k_in = k_in.permute(0, 3, 1, 2)
+        output = k_in
+        # print("image shape:", image.shape)
+
+        output1 = self.small_conv(output)
+        output2 = self.mid_conv(output)
+        output3 = self.large_conv(output)
+
+        output = torch.cat((output1, output2, output3), 1)
+        # print("output range before freq_filter:", output.max(), output.min())
+
+        ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain.
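The comment above appeals to the convolution theorem: an element-wise (learned) weighting of k-space corresponds to a circular convolution with an image-sized kernel in the image domain. A quick single-channel numerical check of that identity (illustrative only, not part of this file):

import torch

x = torch.randn(16, 16)
h = torch.randn(16, 16)
via_fft = torch.fft.ifft2(torch.fft.fft2(x) * torch.fft.fft2(h)).real

direct = torch.zeros_like(x)          # circular convolution computed explicitly
for dx in range(16):
    for dy in range(16):
        direct += h[dx, dy] * torch.roll(x, shifts=(dx, dy), dims=(0, 1))

assert torch.allclose(via_fft, direct, atol=1e-3)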
+ spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class DuDo_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 3, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + # self.kspace_net = mmConvKSpace() + self.kspace_net = kspace_mmUnet(args) + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
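The DataConsistency layer applied at the end of the k-space branch above keeps the originally acquired k-space entries and only lets the network fill in the missing ones. A minimal sketch of that rule on toy tensors (shapes are illustrative, not taken from the repo):

import torch

def data_consistency(k, k0, mask):
    # acquired entries come from k0, the rest from the network output k
    return (1 - mask) * k + mask * k0

k_pred = torch.randn(1, 8, 8, 2)        # network output in k-space, real/imag in the last dim
k0 = torch.randn(1, 8, 8, 2)            # originally acquired (undersampled) k-space
mask = torch.zeros(1, 8, 8, 1)
mask[:, :, ::2, :] = 1                  # pretend every other column was sampled

k_dc = data_consistency(k_pred, k0, mask)
assert torch.equal(k_dc[:, :, ::2, :], k0[:, :, ::2, :])            # sampled columns untouched
assert torch.equal(k_dc[:, :, 1::2, :], k_pred[:, :, 1::2, :])      # unsampled columns from the net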
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + stack = [] + feature_stack = [] + # output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ output = torch.cat((image, LR_image, aux_image), 1) #.detach() + + # print("LR image range:", LR_image.max(), LR_image.min()) + # print("kspace_recon image range:", image.max(), image.min()) + # print("aux_image range:", aux_image.max(), aux_image.min()) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output, image, recon_kspace, masked_kspace, data_stats + + + +class kspace_ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return DuDo_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DuDo_mUnet_ARTfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DuDo_mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6df67b63140ca51bd7b8664151ab4b1b7a6305d5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DuDo_mUnet_ARTfusion.py @@ -0,0 +1,369 @@ +""" +Xiaohan Xing, 2023/11/08 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + ARTfusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # 
print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + output1, output2 = aux_image.detach(), LR_image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
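get_relation_matrix above compares pooled 5x5 positions by cosine similarity, so the diagonal of the resulting 25x25 matrix should be numerically all ones. A quick check of that property on random features (illustrative only):

import torch

feat = torch.randn(2, 25, 8)                              # (bs, positions, channels)
norm = torch.norm(feat, p=2, dim=-1, keepdim=True)        # (bs, positions, 1)
rel = torch.bmm(feat, feat.permute(0, 2, 1)) / torch.bmm(norm, norm.permute(0, 2, 1))
assert torch.allclose(torch.diagonal(rel, dim1=1, dim2=2), torch.ones(2, 25), atol=1e-5)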
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DuDo_mUnet_CatFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DuDo_mUnet_CatFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c2b86b6e2031b1af07ab7cb58efc182b7a2673 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/DuDo_mUnet_CatFusion.py @@ -0,0 +1,338 @@ +""" +Xiaohan Xing, 2023/11/10 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + multi layer concat fusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = 
F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_CATfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t2_gt: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + t2_image = t2_gt * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + # # print("aux_image range:", aux_image.max(), aux_image.min()) + # print("LR_image range:", LR_image.max(), LR_image.min()) + # print("kspace recon_img range:", image.max(), image.min()) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + t2_gt_image = self.normalize(t2_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + t2_gt_image = torch.clamp(t2_gt_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + # output1, output2 = aux_image.detach(), LR_image.detach() + output1, output2 = aux_image.detach(), image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image, recon_kspace, masked_kspace, data_stats + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_CATfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Kspace_ConvNet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Kspace_ConvNet.py new file mode 100644 index 0000000000000000000000000000000000000000..918cbe598a62653394c8c4eaf99cd10a45c064ca --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Kspace_ConvNet.py @@ -0,0 +1,140 @@ +""" +2023/10/05 +Xiaohan Xing +Build a simple network with several convolution layers. 
+Input: concat (undersampled kspace of FS-PD, fully-sampled kspace of PD modality) +Output: interpolated kspace of FS-PD +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, args, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + # self.conv_blocks = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # self.conv_blocks.append(ConvBlock(self.chans, self.chans * 2, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans * 2, self.chans, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans, self.out_chans, drop_prob)) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + # print("output of the three conv blocks:", output1.shape, output2.shape, output3.shape) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("conv1_output range:", output1.max(), output1.min()) + # print("conv2_output range:", output2.max(), output2.min()) + # print("conv3_output range:", output3.max(), output3.min()) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. + spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + # print("output before DC layer:", output.shape) + # print("output range before DC layer:", output.max(), output.min()) + # output = output + kspace ## residual connection + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + # mask_output = output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +def build_model(args): + return mmConvKSpace(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Kspace_mUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Kspace_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2db9a357798be4abd4cfbd44e9207555770b0456 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Kspace_mUnet.py @@ -0,0 +1,202 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # print("output:", output.shape) + # print("kspace:", kspace.shape) + # print("mask:", mask.shape) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Kspace_mUnet_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Kspace_mUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ccfd28dc198ffeff2259a5fbed45dabdc76819 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Kspace_mUnet_new.py @@ -0,0 +1,200 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # print("using Unet without Relu layers") + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + # output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # # print("output:", output.shape) + # # print("kspace:", kspace.shape) + # # print("mask:", mask.shape) + # mask_output = mask * output + + return output # .permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/MINet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeeb063b12be1dc594e8e076e6a59e454fcfa0b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/MINet.py @@ -0,0 +1,356 @@ +""" +Compare with the method "Multi-contrast MRI Super-Resolution via a Multi-stage Integration Network". +The code is from https://github.com/chunmeifeng/MINet/edit/main/fastmri/models/MINet.py. 
+""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F + + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): 
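+    # Note: SR_Branch is the per-contrast RCAN-style backbone. Its forward()
+    # returns (outputs, x), where `outputs` collects intermediate features and
+    # `x` is the tail feature map (still n_feats channels; the 1-channel `final`
+    # conv + Tanh is applied by the wrapping MINet model further down).
+    # With the defaults (n_resgroups=3, n_feats=64), `last_conv` expects
+    # 2*n_feats*(n_resgroups+1) = 512 input channels, i.e. the stacked features
+    # of *both* branches, so calling a single branch's forward() standalone
+    # would hit a channel mismatch in `self.last_conv`; MINet.forward instead
+    # drives the two branches' head/body/la/csa sub-modules directly.
+    # A rough usage sketch of the full model (hypothetical sizes; args is unused):
+    #
+    #     model = build_model(args=None)                      # -> MINet(None)
+    #     pd, pdfs = model(torch.randn(1, 1, 64, 64),
+    #                      torch.randn(1, 1, 64, 64))         # each (1, 1, 64, 64)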
+ def __init__(self,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 1 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + # print("modules_tail:", modules_tail) + # self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Sigmoid()) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + + + +class MINet(nn.Module): + def __init__(self, args, n_resgroups=3, n_resblocks=3, n_feats=64): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + # print("x1:", x1.shape, "x2:", x2.shape) + + ### 源代码中是super-resolution任务,所以self.tail对图像进行unsampling. 
我们做reconstruction把这步去掉 + x2 = self.tail(x2) + + + resT1 = x1 + resT2 = x2 + + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + # print("t1s:", [item.shape for item in t1s]) + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + # print("resT1 range:", resT1.max(), resT1.min()) + # print("resT2 range:", resT2.max(), resT2.min()) + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1, x2 #x1=pd x2=pdfs + + +def build_model(args): + return MINet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/MINet_common.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/MINet_common.py new file mode 100644 index 0000000000000000000000000000000000000000..631f3036db4edf8eab00d70c66d2058a3b65cd59 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/MINet_common.py @@ -0,0 +1,90 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
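+            # `(scale & (scale - 1)) == 0` is the usual power-of-two check: each
+            # iteration below widens channels 4x and applies PixelShuffle(2),
+            # doubling H and W, so log2(scale) stages reach the target factor.
+            # With scale == 1 (the setting used by the reconstruction branches in
+            # this repo) the loop runs zero times and the Upsampler is an empty,
+            # identity nn.Sequential.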
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + # print("upsampler:", m) + + super(Upsampler, self).__init__(*m) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/SANet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/SANet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfac3590e248c9532b002885c6c0b7a55ab1400a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/SANet.py @@ -0,0 +1,392 @@ +""" +This script contains the codes of "Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution (TNNLS 2023)" +https://github.com/chunmeifeng/SANet/blob/main/fastmri/models/SANet.py +""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import os + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, 
C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,scale,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups #10 + self.n_resblocks = n_resblocks #20 + self.n_feats = n_feats #64 + kernel_size = 3 + reduction = 16 #16 + # scale = args.scale[0] + + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + 
print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' + .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class Seatt(nn.Module): + def __init__(self, in_c): + super(Seatt, self).__init__() + self.reduce = nn.Conv2d(in_c * 2, 32, 1) + self.ff_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.bf_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.rgbd_pred_layer = Pred_Layer(32 * 2) + self.convq = nn.Conv2d(32, 32, 3, 1, 1) + + def forward(self, rgb_feat, dep_feat, pred): + feat = torch.cat((rgb_feat, dep_feat), 1) + + # vis_attention_map_rgb_feat(rgb_feat[0][0]) + # vis_attention_map_dep_feat(dep_feat[0][0]) + # vis_attention_map_feat(feat[0][0]) + + + feat = self.reduce(feat) + [_, _, H, W] = feat.size() + pred = torch.sigmoid(pred) + + ni_pred = 1 - pred + + ff_feat = self.ff_conv(feat * pred) + bf_feat = self.bf_conv(feat * (1 - pred)) + new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) + + return new_pred + +class SANet(nn.Module): + def __init__(self, args, scale=1, n_resgroups=3, n_resblocks=3, n_feats=64): + super(SANet, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + + cs = [64,64,64,64] + self.Seatts = nn.ModuleList([Seatt(c) for c in cs]) + + self.net1 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + print("nlayer:",nlayer) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=3, padding=1) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2, Seatt, map_conv in zip(self.net1.body._modules.items(), self.net2.body._modules.items(), self.Seatts, 
self.map_convs): + + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + pred = map_conv(resT1) + res = Seatt(resT1,resT2,pred) + + resT2 = res+resT2 + + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + + return x1,x2 + + +def build_model(args): + return SANet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/SwinFuse_layer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/SwinFuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3dca8ba793d92961bc3f1d987e211fe542b303 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/SwinFuse_layer.py @@ -0,0 +1,1310 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. 
+ num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # print("x shape:", x.shape) + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + # print("num heads:", self.num_heads) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * 
self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + 
relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + # print("x_windows:", x_windows.shape) + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
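+
+    Note: forward(x, x_size) expects token embeddings of shape (B, H*W, C)
+    plus the spatial size x_size = (H, W); windows are re-partitioned inside
+    each block. A rough shape sketch (hypothetical sizes):
+
+        layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
+                           num_heads=6, window_size=7)
+        out = layer(torch.randn(2, 56 * 56, 96), (56, 56))
+        # out: (2, 56*56, 96); no downsampling with the default downsample=None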
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
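+
+    Note: forward(x, y, x_size) runs both token streams through the
+    cross-attention blocks and returns the updated pair (x, y). Per the inline
+    Chinese comment below, x is treated as the primary stream and y mainly
+    assists x's feature extraction.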
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
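+
+    Note: in this modified copy the conv / patch-embed residual path is
+    commented out, so forward() simply returns self.residual_group(x, x_size)
+    without adding x back.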
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
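+
+    Note: residual_group_A / residual_group_B (intra-modal) are constructed but
+    not used in this version of forward(); only the cross-modal residual_group
+    is applied, and the residual additions are disabled (the inline Chinese
+    comments read "do not use the residual connection").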
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + + # # 不使用残差连接 + # x = self.residual_group_A(x, x_size) + # y = self.residual_group_B(y, x_size) + + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion_layer(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
+ Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Fusion_num_heads=[6, 6], + window_size=7, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, img_range=1., resi_connection='1conv', + **kwargs): + super(SwinFusion_layer, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + 
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + + ### fusion layer. + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + self.norm_Fusion_B = norm_layer(self.num_features) + + # print("fusion layers:", len(self.layers_Fusion)) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + + def forward_features_Fusion(self, x, y): + + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + # x = torch.cat([x, y], 1) + # ## Downsample the feature in the channel dimension + # x = self.lrelu(self.conv_after_body_Fusion(x)) + # print("x range after fusion:", x.max(), x.min()) + # print("y 
range after fusion:", y.max(), y.min()) + # print("x:", x.shape) + + return x, y + + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = self.check_image_size(y) + + print("modality1 input:", x.max(), x.min()) + print("modality2 input:", y.max(), y.min()) + + # multi-modal feature fusion. + x, y = self.forward_features_Fusion(x, y) # returns the fused features of the two modalities + print("modality1 output:", x.max(), x.min()) + print("modality2 output:", y.max(), y.min()) + + return x[:, :, :H, :W], y[:, :, :H, :W] + + + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + # SwinFusion_layer only builds the fusion stage, so only layers_Fusion is counted here + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + return flops + + + +if __name__ == '__main__': + width, height = 120, 120 + swinfuse = SwinFusion_layer(img_size=(width, height), window_size=8, depths=[6, 6, 6], embed_dim=60, num_heads=[6, 6, 6]) + # print("SwinFusion layer:", swinfuse) + + A = torch.randn((4, 60, width, height)) + B = torch.randn((4, 60, width, height)) + + x = swinfuse(A, B) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/SwinFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/SwinFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3e72c2a3e8859423f3544043e8b9feaec002d0e2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/SwinFusion.py @@ -0,0 +1,1468 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang.
+# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # print("qkv after reshape:", qkv.shape) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. 
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
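+        Example (a hedged sketch; sizes are assumptions, not values taken from this file):
+            layer = BasicLayer(dim=96, input_resolution=(32, 32), depth=2, num_heads=6, window_size=8)
+            tokens = torch.randn(1, 32 * 32, 96)  # (B, H*W, C) token sequence
+            out = layer(tokens, (32, 32))         # shape is unchanged when downsample is None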
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + # 不使用残差连接 + # print("input x shape:", x.shape) + x = self.residual_group_A(x, x_size) + y = self.residual_group_B(y, x_size) + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
+ + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 1 + embed_dim (int): Patch embedding dimension. Default: 96 + Ex_depths/Fusion_depths/Re_depths (tuple(int)): Depth of the Swin Transformer stages in the extraction, fusion and reconstruction branches. + Ex_num_heads/Fusion_num_heads/Re_num_heads (tuple(int)): Number of attention heads in the corresponding branches. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Re_depths=[4], + Ex_num_heads=[6], Fusion_num_heads=[6, 6], Re_num_heads=[6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinFusion, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + print('in_chans: ', in_chans) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + #### shallow feature extraction modified to use two 3x3 convolutions #### + self.conv_first1_A = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first1_B = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first2_A = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1) + self.conv_first2_B = nn.Conv2d(embed_dim_temp, embed_dim_temp, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + self.Re_num_layers = len(Re_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm =
patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + dpr_Re = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Re_depths))] # stochastic depth decay rule + # build Residual Swin Transformer blocks (RSTB) + self.layers_Ex_A = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_A.append(layer) + self.norm_Ex_A = norm_layer(self.num_features) + + self.layers_Ex_B = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_B.append(layer) + self.norm_Ex_B = norm_layer(self.num_features) + + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + 
self.norm_Fusion_B = norm_layer(self.num_features) + + self.layers_Re = nn.ModuleList() + for i_layer in range(self.Re_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Re_depths[i_layer], + num_heads=Re_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Re[sum(Re_depths[:i_layer]):sum(Re_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Re.append(layer) + self.norm_Re = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
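For reference: the `nearest+conv` branch defined below only registers its layers; the forward pass in this file always routes through the plain convolutional head (`conv_last1`/`conv_last2`/`conv_last3`), and `build_model` passes `upsampler='convs'`, which falls into that default branch. A minimal standalone sketch of how such layers are conventionally chained for x4 upscaling in SwinIR-style heads (assumed wiring; tensor sizes below are illustrative, not taken from this file):

import torch
import torch.nn.functional as F
from torch import nn

embed_dim, num_feat, num_out_ch = 96, 64, 1
conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                     nn.LeakyReLU(inplace=True))
conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

feat = torch.randn(1, embed_dim, 60, 60)  # deep features at low resolution
x = conv_before_upsample(feat)
x = lrelu(conv_up1(F.interpolate(x, scale_factor=2, mode='nearest')))  # x2
x = lrelu(conv_up2(F.interpolate(x, scale_factor=2, mode='nearest')))  # x4
out = conv_last(lrelu(conv_hr(x)))  # (1, 1, 240, 240)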
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last1 = nn.Conv2d(embed_dim, embed_dim_temp, 3, 1, 1) + self.conv_last2 = nn.Conv2d(embed_dim_temp, int(embed_dim_temp/2), 3, 1, 1) + self.conv_last3 = nn.Conv2d(int(embed_dim_temp/2), num_out_ch, 3, 1, 1) + + self.tanh = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features_Ex_A(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_A: + x = layer(x, x_size) + + x = self.norm_Ex_A(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_Ex_B(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_B: + x = layer(x, x_size) + + x = self.norm_Ex_B(x) # B L C + x = self.patch_unembed(x, x_size) + return x + + def forward_features_Fusion(self, x, y): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + # print("x token:", x.shape, "y token:", y.shape) + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + x = torch.cat([x, y], 1) + ## Downsample the feature in the channel dimension + x = self.lrelu(self.conv_after_body_Fusion(x)) + + return x + + def forward_features_Re(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Re: + x = layer(x, x_size) + + x = self.norm_Re(x) # B L C + x = self.patch_unembed(x, x_size) + # Convolution + x = self.lrelu(self.conv_last1(x)) + x = self.lrelu(self.conv_last2(x)) + x = self.conv_last3(x) + return x + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = 
self.check_image_size(y) + + # self.mean_A = self.mean.type_as(x) + # self.mean_B = self.mean.type_as(y) + # self.mean = (self.mean_A + self.mean_B) / 2 + + # x = (x - self.mean_A) * self.img_range + # y = (y - self.mean_B) * self.img_range + + # Feedforward + x = self.forward_features_Ex_A(x) + y = self.forward_features_Ex_B(y) + # print("x before fusion:", x.shape) + # print("y before fusion:", y.shape) + x = self.forward_features_Fusion(x, y) + x = self.forward_features_Re(x) + x = self.tanh(x) + + # x = x / self.img_range + self.mean + # print("H:", H, "upscale:", self.upscale) + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='convs') + + +if __name__ == '__main__': + model = SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='convs') + # print(model) + + A = torch.randn((1, 1, 240, 240)) + B = torch.randn((1, 1, 240, 240)) + + x = model(A, B) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/TransFuse.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/TransFuse.py new file mode 100644 index 0000000000000000000000000000000000000000..56ca3d2b24f382603c876e4f6909363a9794ffa3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/TransFuse.py @@ -0,0 +1,155 @@ +import math +from collections import deque + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torchvision import models + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. 
+ """ + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(n_embd, n_embd) + self.query = nn.Linear(n_embd, n_embd) + self.value = nn.Linear(n_embd, n_embd) + # regularization + self.attn_drop = nn.Dropout(attn_pdrop) + self.resid_drop = nn.Dropout(resid_pdrop) + # output projection + self.proj = nn.Linear(n_embd, n_embd) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop): + super().__init__() + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.mlp = nn.Sequential( + nn.Linear(n_embd, block_exp * n_embd), + nn.ReLU(True), # changed from GELU + nn.Linear(block_exp * n_embd, n_embd), + nn.Dropout(resid_pdrop), + ) + + def forward(self, x): + B, T, C = x.size() + + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + + return x + + +class TransFuse_layer(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, n_embd, n_head, block_exp, n_layer, + num_anchors, seq_len=1, + embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1): + super().__init__() + self.n_embd = n_embd + self.seq_len = seq_len + self.vert_anchors = num_anchors + self.horz_anchors = num_anchors + + # positional embedding parameter (learnable), image + lidar + self.pos_emb = nn.Parameter(torch.zeros(1, 2 * seq_len * self.vert_anchors * self.horz_anchors, n_embd)) + + self.drop = nn.Dropout(embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(n_embd, n_head, + block_exp, attn_pdrop, resid_pdrop) + for layer in range(n_layer)]) + + # decoder head + self.ln_f = nn.LayerNorm(n_embd) + + self.block_size = seq_len + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + + def forward(self, m1, m2): + """ + Args: + m1 (tensor): B*seq_len, C, H, W + m2 (tensor): B*seq_len, C, H, W + """ + + bz = m2.shape[0] // self.seq_len + h, w = m2.shape[2:4] + + # forward the image model for token embeddings + m1 = m1.view(bz, self.seq_len, -1, h, w) + m2 = m2.view(bz, self.seq_len, -1, h, w) + + # pad token embeddings along number of tokens dimension + token_embeddings = torch.cat([m1, m2], 
dim=1).permute(0,1,3,4,2).contiguous() + token_embeddings = token_embeddings.view(bz, -1, self.n_embd) # (B, an * T, C) + + # add (learnable) positional embedding for all tokens + x = self.drop(self.pos_emb + token_embeddings) # (B, an * T, C) + x = self.blocks(x) # (B, an * T, C) + x = self.ln_f(x) # (B, an * T, C) + x = x.view(bz, 2 * self.seq_len, self.vert_anchors, self.horz_anchors, self.n_embd) + x = x.permute(0,1,4,2,3).contiguous() # same as token_embeddings + + m1_out = x[:, :self.seq_len, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + m2_out = x[:, self.seq_len:, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + print("modality1 output:", m1_out.max(), m1_out.min()) + print("modality2 output:", m2_out.max(), m2_out.min()) + + return m1_out, m2_out + + +if __name__ == "__main__": + + feature1 = torch.randn((4, 512, 20, 20)) + feature2 = torch.randn((4, 512, 20, 20)) + model = TransFuse_layer(n_embd=512, n_head=4, block_exp=4, n_layer=8, num_anchors=20, seq_len=1) + print("TransFuse_layer:", model) + feat1, feat2 = model(feature1, feature2) + print(feat1.shape, feat2.shape) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Unet_ART.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Unet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b0f834270a921ae4eb428c92b3e782fbed19b3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/Unet_ART.py @@ -0,0 +1,243 @@ +""" +2023/07/06, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + output = conv_layer(output) + + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..334f1821a9b59e692641f70cdc61e7655b1a9c89 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/__init__.py @@ -0,0 +1,114 @@ +from .unet import build_model as UNET +from .mmunet import build_model as MUNET +from .humus_net import build_model as HUMUS +from .MINet import build_model as MINET +from .SANet import build_model as SANET +from .mtrans_net import build_model as MTRANS +from .unimodal_transformer import build_model as TRANS +from .cnn_transformer_new import build_model as CNN_TRANS +from .cnn import build_model as CNN +from .unet_transformer import build_model as UNET_TRANSFORMER +from .swin_transformer import build_model as SWIN_TRANS +from .restormer import build_model as RESTORMER + + +from .cnn_late_fusion import build_model as MULTI_CNN +from .mmunet_late_fusion import build_model as MMUNET_LATE +from .mmunet_early_fusion import build_model as MMUNET_EARLY +from .trans_unet.trans_unet import build_model as TRANSUNET +from .SwinFusion import build_model as SWINMULTI +from .mUnet_transformer import build_model as MUNET_TRANSFORMER +from .munet_multi_transfuse import build_model as MUNET_TRANSFUSE +from .munet_swinfusion import build_model as MUNET_SWINFUSE +from .munet_multi_concat import build_model as MUNET_CONCAT +from .munet_concat_decomp import build_model as MUNET_CONCAT_DECOMP +from .munet_multi_sum import build_model as MUNET_SUM +# from .mUnet_ARTfusion_new import build_model as MUNET_ART_FUSION +from .mUnet_ARTfusion import build_model as MUNET_ART_FUSION +from .mUnet_Restormer_fusion import build_model as MUNET_RESTORMER_FUSION +from .mUnet_ARTfusion_SeqConcat import build_model as MUNET_ART_FUSION_SEQ + +from .ARTUnet import build_model as ART +from .Unet_ART import build_model as UNET_ART +from .unet_restormer import build_model as UNET_RESTORMER +from .mmunet_ART import build_model as MMUNET_ART +from .mmunet_restormer import build_model as MMUNET_RESTORMER +from .mmunet_restormer_v2 import build_model as MMUNET_RESTORMER_V2 +from .mmunet_restormer_3blocks import build_model as MMUNET_RESTORMER_SMALL + +from .mARTUnet import build_model as ART_MULTI_INPUT + +from .ART_Restormer import build_model as ART_RESTORMER +from .ART_Restormer_v2 import build_model as ART_RESTORMER_V2 +from .swinIR import build_model as SWINIR + +from .Kspace_ConvNet import build_model as KSPACE_CONVNET +from .Kspace_mUnet_new import build_model as KSPACE_MUNET +from .kspace_mUnet_concat import build_model as KSPACE_MUNET_CONCAT +from .kspace_mUnet_AttnFusion import build_model as KSPACE_MUNET_ATTNFUSE + +from .DuDo_mUnet import build_model as DUDO_MUNET +from .DuDo_mUnet_ARTfusion import build_model as 
DUDO_MUNET_ARTFUSION +from .DuDo_mUnet_CatFusion import build_model as DUDO_MUNET_CONCAT + + +model_factory = { + 'unet_single': UNET, + 'humus_single': HUMUS, + 'transformer_single': TRANS, + 'cnn_transformer': CNN_TRANS, + 'cnn_single': CNN, + 'swin_trans_single': SWIN_TRANS, + 'trans_unet': TRANSUNET, + 'unet_transformer': UNET_TRANSFORMER, + 'restormer': RESTORMER, + 'unet_art': UNET_ART, + 'unet_restormer': UNET_RESTORMER, + 'art': ART, + 'art_restormer': ART_RESTORMER, + 'art_restormer_v2': ART_RESTORMER_V2, + 'swinIR': SWINIR, + + + 'munet_transformer': MUNET_TRANSFORMER, + 'munet_transfuse': MUNET_TRANSFUSE, + 'cnn_late_multi': MULTI_CNN, + 'unet_multi': MUNET, + 'unet_late_multi':MMUNET_LATE, + 'unet_early_multi':MMUNET_EARLY, + 'munet_ARTfusion': MUNET_ART_FUSION, + 'munet_restormer_fusion': MUNET_RESTORMER_FUSION, + + + 'minet_multi': MINET, + 'sanet_multi': SANET, + 'mtrans_multi': MTRANS, + 'swin_fusion': SWINMULTI, + 'munet_swinfuse': MUNET_SWINFUSE, + 'munet_concat': MUNET_CONCAT, + 'munet_concat_decomp': MUNET_CONCAT_DECOMP, + 'munet_sum': MUNET_SUM, + 'mmunet_art': MMUNET_ART, + 'mmunet_restormer': MMUNET_RESTORMER, + 'mmunet_restormer_small': MMUNET_RESTORMER_SMALL, + 'mmunet_restormer_v2': MMUNET_RESTORMER_V2, + + 'art_multi_input': ART_MULTI_INPUT, + + 'munet_ARTfusion_SeqConcat': MUNET_ART_FUSION_SEQ, + + 'kspace_ConvNet': KSPACE_CONVNET, + 'kspace_munet': KSPACE_MUNET, + 'kspace_munet_concat': KSPACE_MUNET_CONCAT, + 'kspace_munet_AttnFusion': KSPACE_MUNET_ATTNFUSE, + + 'dudo_munet': DUDO_MUNET, + 'dudo_munet_ARTfusion': DUDO_MUNET_ARTFUSION, + 'dudo_munet_concat': DUDO_MUNET_CONCAT, +} + + +def build_model_from_name(args): + assert args.model_name in model_factory.keys(), 'unknown model name' + + return model_factory[args.model_name](args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c130178e7cdabaaef8686da0cec341265f3c43d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn.py @@ -0,0 +1,246 @@ +""" +把our method中的单模态CNN_Transformer中的transformer换成几个conv_block用来下采样提取特征, 看是否能够达到和Unet相似的效果。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Reconstructor(nn.Module): + def __init__(self): + super(CNN_Reconstructor, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + nlatent = self.encoder.output_dim + self.decoder = Decoder(n_upsample=4, n_res=1, dim=nlatent, output_dim=1, pad_type='zero') + + + def forward(self, m): + #4,1,240,240 + feature_output = self.encoder.model(m) # [4, 256, 60, 60] + # print("output of the encoder:", feature_output.shape) + + output = self.decoder(feature_output) + + return output + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, 
hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + + def forward(self, input): + return 
self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + + +def build_model(args): + return CNN_Reconstructor() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn_late_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..58dacad6910b6e896d00a0d36c9ea60e8d77217f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn_late_fusion.py @@ -0,0 +1,196 @@ +""" +在lequan的代码上,去掉transformer-based feature fusion部分。 +直接用CNN提取特征之后concat作为multi-modal representation. 用普通的decoder进行图像重建。 +把T1作为guidance modality, T2作为target modality. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers1 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + self.down_sample_layers2 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers1.append(ConvBlock(ch, ch * 2, drop_prob)) + self.down_sample_layers2.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + + # self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # ch = chans + # for _ in range(num_pool_layers - 1): + # self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + # ch *= 2 + # self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv = nn.Sequential(nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), nn.Tanh()) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = torch.cat([image, aux_image], 1) + # for layer in self.down_sample_layers: + # output = layer(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + + # apply down-sampling layers + output = image + for layer in self.down_sample_layers1: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + output_m1 = output + + output = aux_image + for layer in self.down_sample_layers2: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + output_m2 = output + + # print("two modality outputs:", output_m1.shape, output.shape) + output = torch.cat([output_m1, output_m2], 1) + + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv in self.up_transpose_conv: + output = transpose_conv(output) + + output = self.up_conv(output) + # print("model output:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..261d4d214e35c3a9be5d28bc426ce7280124723c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn_transformer.py @@ -0,0 +1,289 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 
送到transformer网络中进行MRI重建。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=2, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=2, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 60 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 4 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
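As a quick shape check on the comments above: with two downsampling stages the content encoder outputs 256-channel 60x60 features, and the patch embedding (a stride-4 convolution followed by an einops Rearrange) turns them into 225 tokens of dimension 1024. A standalone sketch using this file's default sizes (variable names here are illustrative):

import torch
from torch import nn
from einops.layers.torch import Rearrange

to_patch_embedding = nn.Sequential(
    nn.Conv2d(256, 1024, kernel_size=4, stride=4),  # 60x60 feature map -> 15x15 patch grid
    Rearrange('b e h w -> b (h w) e'),              # 15*15 = 225 patch tokens of dim 1024
)
feat = torch.randn(4, 256, 60, 60)
tokens = to_patch_embedding(feat)
print(tokens.shape)  # torch.Size([4, 225, 1024])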
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + feature_output = self.upsampling1(feature_output) # [4,512,30,30] + feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn_transformer_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn_transformer_new.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b2dea1e7cc8d456a684b6c839f052383999e84 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/cnn_transformer_new.py @@ -0,0 +1,290 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 送到transformer网络中进行MRI重建。 +2023/05/23: 之前是从尺寸为60*60的特征图上切patch, 现在可以尝试直接用conv_blocks提取15*15的特征图,然后按照patch_size = 1切成225个patches。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=4, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 1 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
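A few lines below, the tokenization is inverted: the cls token is dropped, the sequence is transposed back to channels-first, and the 225 tokens are reshaped into a 15x15 map for the decoder; in this variant the encoder already downsamples 240 -> 15, so the transposed-conv upsampling used in cnn_transformer.py is skipped. A standalone sketch of that reshape (sizes follow this file's comments; names are illustrative):

import numpy as np
import torch

b, patch_dim = 4, 1024
tokens = torch.randn(b, 226, patch_dim)              # 225 patch tokens + 1 cls token
feat = tokens[:, 1:, :].transpose(1, 2)              # drop cls -> (4, 1024, 225)
h = w = int(np.sqrt(feat.shape[-1]))                 # 225 tokens -> 15 x 15 grid
feat = feat.contiguous().view(b, patch_dim, h, w)    # (4, 1024, 15, 15), fed to the decoder
print(feat.shape)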
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + # feature_output = self.upsampling1(feature_output) # [4,512,30,30] + # feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/humus_net.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/humus_net.py new file mode 100644 index 0000000000000000000000000000000000000000..d23701634f849b414866d71b3c23c6d07f3d6437 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/humus_net.py @@ -0,0 +1,1070 @@ +""" +HUMUS-Block +Hybrid Transformer-convolutional Multi-scale denoiser. + +We use parts of the code from +SwinIR (https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py) +Swin-Unet (https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.py) + +HUMUS-Net中是对kspace data操作, 多个unrolling cascade, 每个cascade中都利用HUMUS-block对图像进行denoising. +因为我们的方法输入是Under-sampled images, 所以不需要进行unrolling, 直接使用一个HUMUS-block进行MRI重建。 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # print("x shape:", x.shape) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. 
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + 
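+        # At this point `shifted_x` is back to shape (B, H, W, C): window_partition
+        # split the (optionally cyclically shifted) feature map into non-overlapping
+        # window_size x window_size windows, self-attention ran independently inside
+        # each window, and window_reverse stitched the windows back together.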
# reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + +class PatchExpandSkip(nn.Module): + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) + self.channel_mix = nn.Linear(dim, dim // 2, bias=False) + self.norm = norm_layer(dim // 2) + + def forward(self, x, skip): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x = torch.cat([x, skip], dim=-1) + x = self.channel_mix(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
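+
+    Example (illustrative sketch; the dim/resolution values below are assumptions,
+    not defaults):
+
+        layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
+                           num_heads=3, window_size=7)
+        tokens = torch.randn(1, 56 * 56, 96)    # (B, H*W, C)
+        out = layer(tokens, x_size=(56, 56))    # (1, 3136, 96); downsample=None keeps the shape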
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, torch.tensor(x_size)) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ block_type: + D: downsampling block, + U: upsampling block + B: bottleneck block + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', block_type='B'): + super(RSTB, self).__init__() + self.dim = dim + conv_dim = dim // (patch_size ** 2) + divide_out_ch = 1 + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(conv_dim, conv_dim // divide_out_ch, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(conv_dim, conv_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // divide_out_ch, 3, 1, 1)) + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.block_type = block_type + if block_type == 'B': + self.reshape = torch.nn.Identity() + elif block_type == 'D': + self.reshape = PatchMerging(input_resolution, dim) + elif block_type == 'U': + self.reshape = PatchExpandSkip([res // 2 for res in input_resolution], dim * 2) + else: + raise ValueError('Unknown RSTB block type.') + + def forward(self, x, x_size, skip=None): + if self.block_type == 'U': + assert skip is not None, "Skip connection is required for patch expand" + x = self.reshape(x, skip) + + out = self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) + block_out = self.patch_embed(out) + x + + if self.block_type == 'D': + block_out = (self.reshape(block_out), block_out) # return skip connection + + return block_out + +class PatchEmbed(nn.Module): + r""" Fixed Image to Patch Embedding. Flattens image along spatial dimensions + without learned projection mapping. Only supports 1x1 patches. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + embed_dim (int): Number of output channels. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + +class PatchEmbedLearned(nn.Module): + r""" Image to Patch Embedding with arbitrary patch size and learned linear projection. 
+ Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_chans * patch_size[0] * patch_size[1], embed_dim) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.patch_to_vec(x) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + def patch_to_vec(self, x): + """ + Args: + x: (B, C, H, W) + Returns: + patches: (B, num_patches, patch_size * patch_size * C) + """ + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(B, H // self.patch_size[0], self.patch_size[0], W // self.patch_size[1], self.patch_size[1], C) + patches = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.patch_size[0], self.patch_size[1], C) + patches = patches.contiguous().view(B, self.num_patches, self.patch_size[0] * self.patch_size[1] * C) + return patches + + +class PatchUnEmbed(nn.Module): + r""" Fixed Image to Patch Unembedding via reshaping. Only supports patch size 1x1. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + +class PatchUnEmbedLearned(nn.Module): + r""" Patch Embedding to Image with learned linear projection. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. 
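+
+    Example (illustrative sketch; the sizes below are assumptions chosen so the
+    reshapes work out):
+
+        unembed = PatchUnEmbedLearned(img_size=64, patch_size=2, in_chans=96)
+        tokens = torch.randn(1, 32 * 32, 96)    # (B, num_patches, C)
+        img = unembed(tokens)                   # (1, 96, 64, 64)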
+ """ + + def __init__(self, img_size, patch_size, in_chans=None, out_chans=None, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.out_chans = out_chans + + self.proj_up = nn.Linear(in_chans, in_chans * patch_size[0] * patch_size[1]) + + def forward(self, x, x_size=None): + B, HW, C = x.shape + x = self.proj_up(x) + x = self.vec_to_patch(x) + return x + + def vec_to_patch(self, x): + B, HW, C = x.shape + x = x.view(B, self.patches_resolution[0], self.patches_resolution[1], C).contiguous() + x = x.view(B, self.patches_resolution[0], self.patch_size[0], self.patches_resolution[1], self.patch_size[1], C // (self.patch_size[0] * self.patch_size[1])).contiguous() + x = x.view(B, self.img_size[0], self.img_size[1], -1).contiguous().permute(0, 3, 1, 2) + return x + + +class HUMUSNet(nn.Module): + r""" HUMUS-Block + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depths (tuple(int)): Depth of each Swin Transformer layer in encoder and decoder paths. + num_heads (tuple(int)): Number of attention heads in different layers of encoder and decoder. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + use_checkpoint (bool): Whether to use checkpointing to save memory. + img_range: Image range. 1. or 255. + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + conv_downsample_first: use convolutional downsampling before MUST to reduce compute load on Transformers + """ + + def __init__(self, args, + img_size=[240, 240], + in_chans=1, + patch_size=1, + embed_dim=66, + depths=[2, 2, 2], + num_heads=[3, 6, 12], + window_size=5, + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + img_range=1., + resi_connection='1conv', + bottleneck_depth=2, + bottleneck_heads=24, + conv_downsample_first=True, + out_chans=1, + no_residual_learning=False, + **kwargs): + super(HUMUSNet, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans if out_chans is None else out_chans + + self.center_slice_out = (out_chans == 1) + + self.img_range = img_range + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + self.conv_downsample_first = conv_downsample_first + self.no_residual_learning = no_residual_learning + + ##################################################################################################### + ################################### 1, input block ################################### + input_conv_dim = embed_dim + self.conv_first = nn.Conv2d(num_in_ch, input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** self.num_layers) + self.mlp_ratio = mlp_ratio + + # Downsample for low-res feature extraction + if self.conv_downsample_first: + img_size = [im //2 for im in img_size] + self.conv_down_block = ConvBlock(input_conv_dim // 2, input_conv_dim, 0.0) + self.conv_down = DownsampConvBlock(input_conv_dim, input_conv_dim) + + # split image into non-overlapping patches + if patch_size > 1: + self.patch_embed = PatchEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + if patch_size > 1: + self.patch_unembed = PatchUnEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, out_chans=embed_dim, norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build MUST + # encoder + self.layers_down = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** i_layer) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + 
patches_resolution[1] // dim_scaler), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='D', + ) + self.layers_down.append(layer) + + # bottleneck + dim_scaler = (2 ** self.num_layers) + self.layer_bottleneck = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=bottleneck_depth, + num_heads=bottleneck_heads, + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='B', + ) + + # decoder + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** (self.num_layers - i_layer - 1)) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='U', + ) + self.layers_up.append(layer) + + self.norm_down = norm_layer(self.num_features) + self.norm_up = norm_layer(self.embed_dim) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(input_conv_dim, input_conv_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(input_conv_dim, input_conv_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim, 3, 1, 1)) + + # Upsample if needed + if self.conv_downsample_first: + self.conv_up_block = ConvBlock(input_conv_dim, input_conv_dim // 2, 0.0) + self.conv_up = TransposeConvBlock(input_conv_dim, input_conv_dim // 2) + + ##################################################################################################### + ################################ 3, output block ################################ + self.conv_last = nn.Conv2d(input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, num_out_ch, 3, 1, 1) + self.output_norm = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + 
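+    # Summary of the forward pass when conv_downsample_first=True (see forward()
+    # below): conv_first -> ConvBlock -> strided-conv downsample -> MUST
+    # (Swin encoder / bottleneck / decoder with skip connections) -> conv_after_body
+    # -> transpose-conv upsample -> concat with the conv_first features -> ConvBlock
+    # -> conv_last, plus a residual connection to the (center slice of the) padded
+    # input unless no_residual_learning, a final Tanh, and a crop back to H x W.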
@torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + + # divisible by window size + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + + # divisible by total downsampling ratio + # this could be done more efficiently by combining the two + total_downsamp = int(2 ** (self.num_layers - 1)) + pad_h = (total_downsamp - h % total_downsamp) % total_downsamp + pad_w = (total_downsamp - w % total_downsamp) % total_downsamp + x = F.pad(x, (0, pad_w, 0, pad_h)) + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # encode + skip_cons = [] + for i, layer in enumerate(self.layers_down): + x, skip = layer(x, x_size=layer.input_resolution) + skip_cons.append(skip) + x = self.norm_down(x) # B L C + + # bottleneck + x = self.layer_bottleneck(x, self.layer_bottleneck.input_resolution) + + # decode + for i, layer in enumerate(self.layers_up): + x = layer(x, x_size=layer.input_resolution, skip=skip_cons[-i-1]) + x = self.norm_up(x) + x = self.patch_unembed(x, x_size) + return x + + def forward(self, x): + # print("input shape:", x.shape) + # print("input range:", x.max(), x.min()) + C, H, W = x.shape[1:] + center_slice = (C - 1) // 2 + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.conv_downsample_first: + x_first = self.conv_first(x) + x_down = self.conv_down(self.conv_down_block(x_first)) + res = self.conv_after_body(self.forward_features(x_down)) + res = self.conv_up(res) + res = torch.cat([res, x_first], dim=1) + res = self.conv_up_block(res) + + res = self.conv_last(res) + + if self.no_residual_learning: + x = res + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + res + else: + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + + if self.no_residual_learning: + x = self.conv_last(res) + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + self.conv_last(res) + + # print("output range before scaling:", x.max(), x.min()) + # print("self.mean:", self.mean) + + if self.center_slice_out: + x = x / self.img_range + self.mean[:, center_slice, ...].unsqueeze(1) + else: + x = x / self.img_range + self.mean + + x = self.output_norm(x) + + return x[:, :, :H, :W] + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
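+
+        Example (illustrative):
+
+            block = ConvBlock(in_chans=32, out_chans=64, drop_prob=0.0)
+            out = block(torch.randn(1, 32, 128, 128))   # -> (1, 64, 128, 128)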
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + +class DownsampConvBlock(nn.Module): + """ + A Downsampling Convolutional Block that consists of one strided convolution + layer followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.Conv2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H/2, W/2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return HUMUSNet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/kspace_mUnet_AttnFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/kspace_mUnet_AttnFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..02588c24559c0604ddbbd03a14b3b32e6cff2c8d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/kspace_mUnet_AttnFusion.py @@ -0,0 +1,300 @@ +""" +Xiaohan Xing, 2023/11/22 +两个模态分别用CNN提取多个层级的特征, 每个层级都把特征沿着宽度拼接起来,得到(h, 2w, c)的特征。 +然后学习得到(h, 2w)的attention map, 和原始的特征相乘送到后面的层。 +""" + +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.attn_layers = nn.ModuleList() + + for l in range(self.num_pool_layers): + + self.attn_layers.append(nn.Sequential(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob), + nn.Conv2d(chans*(2**l), 2, kernel_size=1, stride=1), + nn.Sigmoid())) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, attn_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.attn_layers): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat之后求attention, 然后进行modulation. 
+ attn = attn_layer(torch.cat((output1, output2), 1)) + # print("spatial attention:", attn[:, 0, :, :].unsqueeze(1).shape) + # print("output1:", output1.shape) + output1 = output1 * (1 + attn[:, 0, :, :].unsqueeze(1)) + output2 = output2 * (1 + attn[:, 1, :, :].unsqueeze(1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
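+
+        Example (illustrative):
+
+            up = TransposeConvBlock(in_chans=64, out_chans=32)
+            out = up(torch.randn(1, 64, 60, 60))    # -> (1, 32, 120, 120)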
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/kspace_mUnet_concat.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/kspace_mUnet_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..d16bfdc83305d3e1f490e5ca4b445c6d730557ce --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/kspace_mUnet_concat.py @@ -0,0 +1,351 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Predict the low-frequency mask ############## +class Mask_Predictor(nn.Module): + def __init__(self, in_chans, out_chans, chans, drop_prob): + super(Mask_Predictor, self).__init__() + self.conv1 = ConvBlock(in_chans, chans, drop_prob) + self.conv2 = ConvBlock(chans, chans*2, drop_prob) + self.conv3 = ConvBlock(chans*2, chans*4, drop_prob) + + # self.conv4 = ConvBlock(chans*4, 1, drop_prob) + + self.FC = nn.Linear(chans*4, out_chans) + self.act = nn.Sigmoid() + + + def forward(self, x): + ### 三层卷积和down-sampling. + # print("[Mask predictor] input data range:", x.max(), x.min()) + output = F.avg_pool2d(self.conv1(x), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer1_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv2(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer2_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv3(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer3_feature range:", output.max(), output.min()) + + feature = F.adaptive_avg_pool2d(F.relu(output), (1, 1)).squeeze(-1).squeeze(-1) + # print("mask_prediction feature:", output[:, :5]) + # print("[Mask predictor] bottleneck feature range:", feature.max(), feature.min()) + + output = self.FC(feature) + # print("output before act:", output) + output = self.act(output) + # print("mask output:", output) + + return output + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, 
kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.mask_predictor1 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + self.mask_predictor2 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
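+
+        Note: besides the two reconstructed outputs, this variant also returns the
+        predicted low-frequency mask coordinates and DC weights of both branches.
+        Illustrative sketch (`args` only needs an `HF_refine` attribute; the
+        240x240 input size is an assumption):
+
+            from types import SimpleNamespace
+            model = mUnet_multi_fuse(SimpleNamespace(HF_refine="False"))
+            k1 = torch.randn(1, 2, 240, 240)
+            k2 = torch.randn(1, 2, 240, 240)
+            out1, out2, coords1, coords2, w1, w2 = model(k1, k2, None, None)
+            # out1 / out2: (1, 2, 240, 240); coords1 / coords2: (1, 2); w1 / w2: (1,)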
+ """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + ### 预测做low-freq DC的mask区域. + ### 输出三个维度, 前两维是坐标,最后一维是mask内部的权重 + t1_mask = self.mask_predictor1(output1) + t2_mask = self.mask_predictor2(output2) + + # print("t1_mask and t2_mask:", t1_mask.shape, t2_mask.shape) + t1_mask_coords, t2_mask_coords = t1_mask[:, :2], t2_mask[:, :2] + t1_DC_weight, t2_DC_weight = t1_mask[:, -1], t2_mask[:, -1] + + + # ### 预测做low-freq DC的DC weight + # t1_DC_weight = self.mask_predictor1(output1) + # t2_DC_weight = self.mask_predictor2(output2) + + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight ## + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mARTUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9607d41ca14562a94e542a1804af5aeaf15bc123 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mARTUnet.py @@ -0,0 +1,563 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +2023/07/20, Xiaohan Xing +将两个模态的图像在输入端concat, 得到num_channels = 2的input, 然后送到ART模型进行T2 modality的重建。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # print("feature after layer_norm:", x.max(), x.min()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", 
attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=2, + out_channels=1, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img, aux_img): + stack = [] + bs, _, H, W = inp_img.shape + fuse_img = torch.cat((inp_img, aux_img), 1) + inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + 
self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=2, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_ARTfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..610b1e1e38d74f695a3a19f22a6a80f3152bf713 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_ARTfusion.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import CS_TransformerBlock as TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + 
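+# Illustrative sketch (not part of the model): the shape bookkeeping assumed by the ART fusion
+# layers in mUnet_ARTfusion below. At pyramid level l each modality feature has chans * 2**l
+# channels, so the concatenated feature handled by the fusion blocks has chans * 2**(l + 1)
+# channels at a spatial size of (H // 2**l, W // 2**l), before the next average pooling.
+def _fusion_level_shapes(chans, num_pool_layers, H, W):
+    """Expected (channels, height, width) of the concatenated features at each fusion level."""
+    return [(chans * 2 ** (l + 1), H // 2 ** l, W // 2 ** l) for l in range(num_pool_layers)]
+
+# Example: _fusion_level_shapes(32, 4, 240, 240)
+# -> [(64, 240, 240), (128, 120, 120), (256, 60, 60), (512, 30, 30)]
+
+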
+class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + # output1, output2 = image1, image2 + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
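+            ### Per-level fusion step: each modality is first refined by its own ConvBlock and the
+            ### pre-fusion features are pushed onto stack1/stack2 for the decoder skip connections;
+            ### the two features are then concatenated along channels, flattened to tokens, passed
+            ### through this level's ART blocks, split back into two streams, and average-pooled
+            ### before the next level.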
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_ARTfusion_SeqConcat.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_ARTfusion_SeqConcat.py new file mode 100644 index 0000000000000000000000000000000000000000..55280925415f1db47cf746b59fb9710d47c9bff2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_ARTfusion_SeqConcat.py @@ -0,0 +1,374 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any + +import matplotlib.pyplot as plt +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet_new import ConcatTransformerBlock, TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + 
+ return output + + + + +class mUnet_ARTfusion_SeqConcat(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4. + ): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.vis_feature_maps = False + self.vis_attention_maps = False + # self.vis_feature_maps = args.MODEL.VIS_FEAT_MAPS + # self.vis_attention_maps = args.MODEL.VIS_ATTN_MAPS + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + if l < 2: + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * (2 ** (l + 1))), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + else: + self.fuse_layers.append(nn.ModuleList([ConcatTransformerBlock(dim=int(chans * (2 ** l)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + feature_maps = { + 'modal1': {'before': [], 'after': []}, + 'modal2': {'before': [], 'after': []} + } + + attention_maps = [] + """ + attention_maps = [ + unet layer 1 attention maps (x2): [map from TransBlock1, TransBlock2, ...], + unet layer 2 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 3 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 4 attention maps (x6): [map from TransBlock1, TransBlock2, ...], + ] + """ + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
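+            ### In this variant the first two levels (l < 2) fuse the modalities by channel
+            ### concatenation through shared TransformerBlocks, while the deeper levels keep two
+            ### token streams that interact through ConcatTransformerBlocks; note that here
+            ### stack1/stack2 receive the post-fusion features for the decoder skip connections
+            ### (the appends come after the fuse layers below).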
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + if self.vis_feature_maps: + feature_maps['modal1']['before'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['before'].append(output2.detach().cpu().numpy()) + + ### 两个模态的特征concat + # output = torch.cat((output1, output2), 1) + + + + unet_layer_attn_maps = [] + """ + unet_layer_attn_maps: [ + attn_map from TransformerBlock1: (B, nHGroups, nWGroups, winSize, winSize), + attn_map from TransformerBlock2: (B, nHGroups, nWGroups, winSize, winSize), + ... + ] + """ + if l < 2: + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + else: + output1 = rearrange(output1, "b c h w -> b (h w) c").contiguous() + output2 = rearrange(output2, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + if l < 2: + if self.vis_attention_maps: + output, attn_map = layer(output, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output = layer(output, [H // (2**l), W // (2**l)]) + + else: + if self.vis_attention_maps: + output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + attention_maps.append(unet_layer_attn_maps) + + if l < 2: + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + else: + output1 = rearrange(output1, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output2 = rearrange(output2, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + # if self.vis_attention_maps: + # output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) # attn_map: (B, nHGroups, nWGroups, winSize, winSize) + # unet_layer_attn_maps.append(attn_map) + # else: + # output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + # attention_maps.append(unet_layer_attn_maps) + + + + if self.vis_feature_maps: + feature_maps['modal1']['after'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['after'].append(output2.detach().cpu().numpy()) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + if self.vis_feature_maps: + return output1, output2, feature_maps#, relation_stack1, relation_stack2 #, t1_features, t2_features + elif self.vis_attention_maps: + return output1, output2, attention_maps + else: + return output1, output2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
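+    In English: the feature map of each layer is rearranged into 15x15 tiles stacked along the
+    channel dimension, adaptively average-pooled to 5x5, and the cosine similarity between the
+    resulting 5*5 = 25 positions is returned as a (bs, 25, 25) relation matrix.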
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion_SeqConcat(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_Restormer_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_Restormer_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8ab29b641fa0231ad0d163224a4a90b44c32a9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_Restormer_fusion.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .restormer import TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8]): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.Sequential(*[TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[0], ffn_expansion_factor=2.66, \ + bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + ### 将encoder中multi-modal fusion之后的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = fuse_layer(output) + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2a209758203b198502154b75496333ef91717a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mUnet_transformer.py @@ -0,0 +1,387 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+ +Xiaohan Xing, 2023/06/08 +分别用CNN提取两个模态的特征,然后送到transformer进行特征交互, 将经过transformer提取的特征和之前CNN提取的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Decoder_wo_skip(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder_wo_skip, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + output = transpose_conv(output) + output = conv(output) + + return output + + + + + +class mUnet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
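+
+        Note: the transformer fusion below assumes a 15x15 encoder output (fmp_size = 15,
+        i.e. 240x240 inputs with 4 pooling layers), since pos_embedding is sized for
+        2 * 15 * 15 + 1 tokens.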
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch*2 + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1, output2 = image1, image2 + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack1, output1 = self.encoder1(output1) ### output size: [4, 256, 15, 15] + stack2, output2 = self.encoder2(output2) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output1) # [4, 225, 256], 225 is the number of patches. + patch_embed_input1 = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + n_embed = self.to_patch_embedding(output2) # [4, 225, 256], 225 is the number of patches. 
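+ # each 15x15 feature map is flattened into 225 tokens of dim `ch` (256 by default);
+ # both modalities' tokens are L2-normalized, prepended with a cls token, and passed
+ # jointly through the transformer so the features of the two modalities can interact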
+ patch_embed_input2 = F.normalize(n_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input1, patch_embed_input2), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1]/2)), int(np.sqrt(feature_output.shape[1]/2)) + patch_num = feature_output.shape[1] + feature_output1 = feature_output[:, :patch_num//2, :] + feature_output2 = feature_output[:, patch_num//2:, :] + feature_output1 = feature_output1.contiguous().view(b, feature_output1.shape[-1], h, w) # [4,c,15,15] + feature_output2 = feature_output2.contiguous().view(b, feature_output2.shape[-1], h, w) + + output = torch.cat((output1, output2), 1) + torch.cat((feature_output1, feature_output2), 1) + # print("CNN feature range:", output1.max(), output1.min()) + # print("Transformer feature range:", feature_output1.max(), feature_output1.min()) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_Transformer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee0c4784d04638df4ca259613777dde34980a2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/04 +Concatenate the input images from different modalities and input it into the UNet. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
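+ # early fusion: the auxiliary modality is stacked along the channel axis, so the
+ # first ConvBlock sees both images at once (hence the default input_dim = 2)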
+ # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_ART.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..1a09f2900a8924cdc1a7c373270e6ff8aadcfa45 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_ART.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
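+
+ Note:
+ After every encoder ConvBlock, the feature map is refined by a stack of ART
+ TransformerBlocks whose `ds_flag` alternates between 0 and 1 (presumably dense and
+ sparse/dilated window attention, following ART); the spatial size handed to these
+ blocks assumes 240x240 inputs, since `self.img_size` is hard-coded to 240.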
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
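+
+ Note:
+ `aux_image` is the guidance modality; it is concatenated with `image` along the
+ channel dimension before the first ConvBlock, which is why `input_dim` defaults to 2.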
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_ART_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_ART_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe6bc0d9192a126f39932b40443badb5b60068 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_ART_v2.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
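+
+ Note:
+ Compared with mmunet_ART.py, this version builds the TransformerBlocks with their
+ default `pre_norm` setting (rather than `pre_norm=False`) and keeps the final
+ nn.Tanh(), so reconstructions are bounded to [-1, 1].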
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_early_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_early_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d062f057a2080a471005125a00ee27453a4b0706 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_early_fusion.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +encapsulate the encoder and decoder for the Unet model. 
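+Early-fusion variant: the two modality images are concatenated along the channel
+dimension and reconstructed by a single shared Encoder, bottleneck ConvBlock and Decoder.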
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + stack, output = self.encoder(output) + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + output = self.decoder(output, stack) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_late_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b48e7d4445afd6a5bc1eaa2b065c37431585e94 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_late_fusion.py @@ -0,0 +1,229 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +use Unet for the reconstruction of each modality, concat the intermediate features of both modalities. 
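+Each modality has its own Encoder; the two bottleneck features are concatenated and
+fused by a ConvBlock, and modality-specific Decoders (each with its own skip
+connections) reconstruct the two images.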
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("decoder output:", output.shape) + + return output + + + + +class mmUnet_late(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack1, output1 = self.encoder1(image) + stack2, output2 = self.encoder2(aux_image) + ### fuse two modalities to get multi-modal representation. + output = torch.cat([output1, output2], 1) + output = self.conv(output) ### from [4, 512, 15, 15] to [4, 512, 15, 15] + + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet_late(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..88c7920d65c58ab82a00309a1ae501a8d5b33804 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_restormer.py @@ -0,0 +1,237 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
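# final decoder stage: ConvBlock followed by a 1x1 conv back to out_chans; the
+ # Tanh output activation is left commented out in this variant
+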
self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_restormer_3blocks.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_restormer_3blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e3afe68612727d3195872a121ec8040fb0f66ef2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_restormer_3blocks.py @@ -0,0 +1,235 @@ +""" +2023/07/18, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +之前 3 blocks的模型深层特征存在严重的gradient vanishing问题。现在把模型的encoder和decoder都改成只有3个blocks, 看是否还存在gradient vanishing问题。 +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 3, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
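+
+ Note:
+ This variant trims the model to `num_pool_layers = 3` (Restormer stacks with 1/2/4
+ heads) to check whether the gradient-vanishing problem seen in the deeper
+ mmunet_restormer.py model still appears, as described in the module docstring above.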
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2] + heads=[1, 2, 4] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_restormer_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b51d3a218057206863230973d557173e891d3cd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mmunet_restormer_v2.py @@ -0,0 +1,220 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +这个版本中是将Unet encoder提取的特征直接连接到decoder, 经过restormer refine的特征只是传递到下一层的encoder block. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + output = feature + + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mtrans_mca.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mtrans_mca.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4c89df4af71e986fa7c24d522e6732371f3128 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mtrans_mca.py @@ -0,0 +1,119 @@ +import torch +from torch import nn + +from .mtrans_transformer import Mlp + +# implement by YunluYan + + +class LinearProject(nn.Module): + def __init__(self, dim_in, dim_out, fn): + super(LinearProject, self).__init__() + need_projection = dim_in != dim_out + self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity() + self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity() + self.fn = fn + + def forward(self, x, *args, **kwargs): + x = self.project_in(x) + x = self.fn(x, *args, **kwargs) + x = self.project_out(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): + super(MultiHeadCrossAttention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim*2, bias=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x, complement): + + # x [B, HW, C] + B_x, N_x, C_x = x.shape + + x_copy = x + + complement = torch.cat([x, complement], 1) + + B_c, N_c, C_c = complement.shape + + # q [B, heads, HW, C//num_heads] + q = self.to_q(x).reshape(B_x, N_x, self.num_heads, C_x//self.num_heads).permute(0, 2, 1, 3) + kv = self.to_kv(complement).reshape(B_c, N_c, 2, self.num_heads, C_c//self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_x, N_x, C_x) + + x = x + x_copy + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio = 1., attn_drop=0., proj_drop=0.,drop_path = 0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super(CrossTransformerEncoderLayer, self).__init__() + self.x_norm1 = norm_layer(dim) + self.c_norm1 = norm_layer(dim) + + self.attn = MultiHeadCrossAttention(dim, num_heads, attn_drop, proj_drop) + + self.x_norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) + + self.drop1 = nn.Dropout(drop_path) + self.drop2 = nn.Dropout(drop_path) + + def forward(self, x, complement): + x = self.x_norm1(x) + complement = self.c_norm1(complement) + + x = x + self.drop1(self.attn(x, complement)) + x = x + self.drop2(self.mlp(self.x_norm2(x))) + return x + + + +class CrossTransformer(nn.Module): + def __init__(self, x_dim, c_dim, depth, num_heads, mlp_ratio =1., attn_drop=0., proj_drop=0., drop_path =0.): + super(CrossTransformer, self).__init__() + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LinearProject(x_dim, c_dim, CrossTransformerEncoderLayer(c_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)), 
+ LinearProject(c_dim, x_dim, CrossTransformerEncoderLayer(x_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)) + ])) + + def forward(self, x, complement): + for x_attn_complemnt, complement_attn_x in self.layers: + x = x_attn_complemnt(x, complement=complement) + x + complement = complement_attn_x(complement, complement=x) + complement + return x, complement + + + + + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mtrans_net.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mtrans_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e80dadb45725603c7ff1da664a45533575db25 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mtrans_net.py @@ -0,0 +1,145 @@ +""" +This script implement the paper "Multi-Modal Transformer for Accelerated MR Imaging (TMI 2022)" +""" + + +import torch + +from torch import nn +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler +from .mtrans_mca import CrossTransformer + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(INPUT_DIM, HEAD_HIDDEN_DIM): + return ReconstructionHead(INPUT_DIM, HEAD_HIDDEN_DIM) + + +class CrossCMMT(nn.Module): + + def __init__(self, args): + super(CrossCMMT, self).__init__() + + INPUT_SIZE = 240 + INPUT_DIM = 1 # the channel of input + OUTPUT_DIM = 1 # the channel of output + HEAD_HIDDEN_DIM = 16 # the hidden dim of Head + TRANSFORMER_DEPTH = 4 # the depth of the transformer + TRANSFORMER_NUM_HEADS = 4 # the head's num of multi head attention + TRANSFORMER_MLP_RATIO = 3 # the MLP RATIO Of transformer + TRANSFORMER_EMBED_DIM = 256 # the EMBED DIM of transformer + P1 = 8 + P2 = 16 + CTDEPTH = 4 + + + self.head = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + self.head2 = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + + x_patch_dim = HEAD_HIDDEN_DIM * P1 ** 2 + x_num_patches = (INPUT_SIZE // P1) ** 2 + + complement_patch_dim = HEAD_HIDDEN_DIM * P2 ** 2 + complement_num_patches = (INPUT_SIZE // P2) ** 2 + + self.x_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P1, + p2=P1), + ) + + self.complement_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P2, + p2=P2), + ) + + self.x_pos_embedding = nn.Parameter(torch.randn(1, x_num_patches, x_patch_dim)) + self.complement_pos_embedding = nn.Parameter(torch.randn(1, complement_num_patches, complement_patch_dim)) + + ### + self.cross_transformer = CrossTransformer(x_patch_dim, complement_patch_dim, CTDEPTH, + TRANSFORMER_NUM_HEADS, TRANSFORMER_MLP_RATIO) + + self.p1 = P1 + 
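+ # P1 / P2 are the patch sizes used in forward() to fold the transformer tokens of the x branch and the complement branch back into feature maps.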
self.p2 = P2 + + self.tail1 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + self.tail2 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x, complement): + x = self.head(x) + complement = self.head2(complement) + + x = x + complement + + b, _, h, w = x.shape + + _, _, c_h, c_w = complement.shape + + ### 得到每个模态的patch embedding (加上了position encoding) + x = self.x_patch_embbeding(x) + x += self.x_pos_embedding + + complement = self.complement_patch_embbeding(complement) + complement += self.complement_pos_embedding + + ### 将两个模态的patch embeddings送到cross transformer中提取特征。 + x, complement = self.cross_transformer(x, complement) + + c = int(x.shape[2] / (self.p1 * self.p1)) + H = int(h / self.p1) + W = int(w / self.p1) + + x = x.reshape(b, H, W, self.p1, self.p1, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w, ) + x = self.tail1(x) + + complement = complement.reshape(b, int(c_h/self.p2), int(c_w/self.p2), self.p2, self.p2, int(complement.shape[2]/self.p2/self.p2)) + complement = complement.permute(0, 5, 1, 3, 2, 4) + complement = complement.reshape(b, -1, c_h, c_w) + + complement = self.tail2(complement) + + return x, complement + + +def build_model(args): + return CrossCMMT(args) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mtrans_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mtrans_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..61314a6a81247b0740a1e98d078242b26ce73f90 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/mtrans_transformer.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def 
__init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(args, embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=args.MODEL.TRANSFORMER_DEPTH, + num_heads=args.MODEL.TRANSFORMER_NUM_HEADS, + mlp_ratio=args.MODEL.TRANSFORMER_MLP_RATIO) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_concat_decomp.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_concat_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..07c6f64804dd586927814801b167163f0b65f58f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_concat_decomp.py @@ -0,0 +1,295 @@ +""" +Xiaohan Xing, 2023/06/25 +两个模态分别用CNN提取多个层级的特征, 每个层级的特征都经过concat之后得到fused feature, +然后每个模态都通过一层conv变换得到2C*h*w的特征, 其中C个channels作为common feature, 另C个channels作为specific features. 
+利用CDDFuse论文中的decomposition loss约束特征解耦。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
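+ Note: at every encoder level each modality's features are expanded to twice the channels and split into a "common" half and a "specific" half; the common half of the other modality is concatenated in before the per-modality fusion conv. forward() also returns the stacks of common/specific features so that a CDDFuse-style decomposition loss can be applied outside this module.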
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.transform_layers1 = nn.ModuleList() + self.transform_layers2 = nn.ModuleList() + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.transform_layers1.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.transform_layers2.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + stack1_common, stack1_specific, stack2_common, stack2_specific = [], [], [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_transform_layer, net2_transform_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.transform_layers1, self.transform_layers2, \ + self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + # print("original feature:", output1.shape) + + ### 分别提取两个模态的common feature和specific feature + output1 = net1_transform_layer(output1) + # print("transformed feature:", output1.shape) + output1_common = output1[:, :(output1.shape[1]//2), :, :] + output1_specific = output1[:, (output1.shape[1]//2):, :, :] + output2 = net2_transform_layer(output2) + output2_common = output2[:, :(output2.shape[1]//2), :, :] + output2_specific = output2[:, (output2.shape[1]//2):, :, :] + + stack1_common.append(output1_common) + stack1_specific.append(output1_specific) + stack2_common.append(output2_common) + stack2_specific.append(output2_specific) + + ### 将一个模态的两组feature和另一模态的common feature合并, 然后经过conv变换得到和初始特征相同维度的fused feature送到下一层。 + output1 = net1_fuse_layer(torch.cat((output1_common, output1_specific, output2_common), 1)) + output2 = net2_fuse_layer(torch.cat((output2_common, output2_specific, output1_common), 1)) + # print("fused feature:", output1.shape) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, stack1_common, stack1_specific, stack2_common, stack2_specific + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_multi_concat.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_multi_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90373e1134210809b98fdc64e3d51d76123855 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_multi_concat.py @@ -0,0 +1,306 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != 
downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + # output1, output2 = image1, image2 + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
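+ ### (EN) The encoder features from before multi-modal fusion are passed to the decoder through the skip connections.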
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # relation1 = get_relation_matrix(output1) + # relation2 = get_relation_matrix(output2) + # relation_stack1.append(relation1) + # relation_stack2.append(relation2) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + # output1 = net1_layer(output1) + # output2 = net2_layer(output2) + + # ### 两个模态的特征concat + # fused_output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + # fused_output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # stack1.append(fused_output1) + # stack2.append(fused_output2) + + # ### downsampling features + # output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + # output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_multi_sum.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_multi_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4efb13b54ccbeaeb34fc51d3e72b0c50ee242 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_multi_sum.py @@ -0,0 +1,270 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != 
downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers): + + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + + ## 两个模态的特征求平均. + fused_output1 = (output1 + output2)/2.0 + fused_output2 = (output1 + output2)/2.0 + + output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features #, relation_stack1, relation_stack2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
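+ (EN) Resize each layer's feature map to 5x5 anchor positions and compute the 25x25 relation matrix (cosine similarity) between those positions.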
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_multi_transfuse.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_multi_transfuse.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ce250a078ba0847c729c1ef1b76452f64d0f07 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_multi_transfuse.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +2023/06/23 +每个层级的特征用TransFuse layer融合之后, 将两个模态的original features和transfuse features concat, +然后在每个模态中经过conv层变换得到下一层的输入。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .TransFuse import TransFuse_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_TransFuse(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + ### conv layer to transform the features after Transfuse layer. + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + self.n_anchors = 15 + self.avgpool = nn.AdaptiveAvgPool2d((self.n_anchors, self.n_anchors)) + self.transfuse_layer1 = TransFuse_layer(n_embd=self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer2 = TransFuse_layer(n_embd=2*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer3 = TransFuse_layer(n_embd=4*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer4 = TransFuse_layer(n_embd=8*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + + self.transfuse_layers = nn.Sequential(self.transfuse_layer1, self.transfuse_layer2, self.transfuse_layer3, self.transfuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, transfuse_layer, net1_fuse_layer, net2_fuse_layer in zip(self.encoder1.down_sample_layers, \ + self.encoder2.down_sample_layers, self.transfuse_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### extract multi-level features from the encoder of each modality. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + ### Transformer-based multi-modal fusion layer + feature1 = self.avgpool(output1) + feature2 = self.avgpool(output2) + feature1, feature2 = transfuse_layer(feature1, feature2) + feature1 = F.interpolate(feature1, scale_factor=output1.shape[-1]//self.n_anchors, mode='bilinear') + feature2 = F.interpolate(feature2, scale_factor=output2.shape[-1]//self.n_anchors, mode='bilinear') + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_TransFuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_swinfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_swinfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6203c4331744f1da70f18e403360196a419b4cb5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/munet_swinfusion.py @@ -0,0 +1,268 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .SwinFuse_layer import SwinFusion_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_SwinFuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过SwinFusion的方式融合。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. 
+ num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.fuse_type = 'swinfuse' + self.img_size = 240 + # self.fuse_type = 'sum' + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # Swin transformer fusion modules + + self.fuse_layer1 = SwinFusion_layer(img_size=(self.img_size//2, self.img_size//2), patch_size=4, window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer2 = SwinFusion_layer(img_size=(self.img_size//4, self.img_size//4), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=2*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer3 = SwinFusion_layer(img_size=(self.img_size//8, self.img_size//8), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=4*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer4 = SwinFusion_layer(img_size=(self.img_size//16, self.img_size//16), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=8*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.swinfusion_layers = nn.Sequential(self.fuse_layer1, self.fuse_layer2, self.fuse_layer3, self.fuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, swinfuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.swinfusion_layers): + output1 = net1_layer(output1) + stack1.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # print("output of encoders:", output1.shape, output2.shape) + # print("CNN feature range:", output1.max(), output2.min()) + + ### Swin transformer-based multi-modal fusion. 
+ feature1, feature2 = swinfuse_layer(output1, output2) + output1 = (output1 + feature1 + output2 + feature2)/4.0 + output2 = (output1 + feature1 + output2 + feature2)/4.0 + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_SwinFuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/original_MINet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/original_MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a115872320f677d8b34264eed95c22fff0df9c4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/original_MINet.py @@ -0,0 +1,335 @@ +from fastmri.models import common +import torch +import torch.nn as nn +import torch.nn.functional as F + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = 
self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,n_resgroups,n_resblocks,n_feats,conv=common.default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 2 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + common.Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class MINet(nn.Module): + def __init__(self, n_resgroups,n_resblocks, n_feats): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1,x2#x1=pd x2=pdfs + \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3699cc809e438be4e1bd5efaba7d96e18d7f1602 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/restormer.py @@ 
-0,0 +1,307 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + # print("feature before FFN projection:", x.max(), x.min()) + x = self.project_out(x) + # print("FFN output:", x.max(), x.min()) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, 
bias=bias) + + + def forward(self, x): + b,c,h,w = x.shape + # print("Attention block input:", x.max(), x.min()) + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + # print("attention values:", attn) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("Attention block output:", out.max(), out.min()) + + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + # dim = 32, + # num_blocks = [4,4,4,4], + dim = 32, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + + # stack = [] + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + # print("encoder block1 feature:", out_enc_level1.shape, out_enc_level1.max(), out_enc_level1.min()) + # stack.append(out_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + # print("encoder block2 feature:", out_enc_level2.shape, out_enc_level2.max(), out_enc_level2.min()) + # stack.append(out_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + # print("encoder block3 feature:", out_enc_level3.shape, out_enc_level3.max(), out_enc_level3.min()) + # 
stack.append(out_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + # print("encoder block4 feature:", latent.shape, latent.max(), latent.min()) + # stack.append(latent) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + + return out_dec_level1 #, stack + + + +def build_model(args): + return Restormer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/restormer_block.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/restormer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..170ae6a826e5a3e356cb48f8c89500c2d5242816 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/restormer_block.py @@ -0,0 +1,321 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight 
+ self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + # print("input to the LayerNorm layer:", x.max(), x.min()) + h, w = x.shape[-2:] + x = to_4d(self.body(to_3d(x)), h, w) + # print("output of the LayerNorm:", x.max(), x.min()) + return x + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature = 1.0 / ((dim//num_heads) ** 0.5) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + # print("input to the Attention layer:", x.max(), x.min()) + + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + # print("q range:", q.max(), q.min()) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + # print("scaling parameter:", self.temperature) + # print("attention matrix before softmax:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + attn = attn.softmax(dim=-1) + # print("attention matrix range:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + # print("v range:", v.max(), v.min(), v.mean()) + + out = (attn @ v) + # print("self-attention output:", out.max(), out.min()) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("output of attention layer:", out.max(), out.min()) + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, pre_norm): + super(TransformerBlock, self).__init__() + + self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + self.pre_norm = pre_norm 
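# A minimal usage sketch for this block variant, assuming it is importable as below (the import
# path and the dim/num_heads values are illustrative assumptions): compared with the
# TransformerBlock in restormer.py above, this one adds a 1x1 input conv, uses a fixed
# 1/sqrt(head_dim) attention temperature instead of a learnable one, and exposes a pre_norm flag
# switching between pre-LayerNorm (x + attn(norm(x))) and post-LayerNorm (norm(x + attn(x)))
# residual ordering.
import torch
from networks_time.compare_models.restormer_block import TransformerBlock

block = TransformerBlock(dim=48, num_heads=1, ffn_expansion_factor=2.66,
                         bias=False, LayerNorm_type='WithBias', pre_norm=True)
x = torch.randn(1, 48, 64, 64)
print(block(x).shape)    # torch.Size([1, 48, 64, 64])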
+ + def forward(self, x): + x = self.conv(x) + # print("restormer block input:", x.max(), x.min()) + if self.pre_norm: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + else: + x = self.norm1(x + self.attn(x)) + x = self.norm2(x + self.ffn(x)) + # x = self.norm1(self.attn(x)) + # x = self.norm2(self.ffn(x)) + # x = x + self.attn(x) + # x = x + self.ffn(x) + # print("restormer block output:", x.max(), x.min()) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim = 48, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + 
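# Standalone shape check of the resizing modules used around these skip connections (dim=48 is
# this file's default; the 60x60 input is an illustrative assumption): Downsample(n) is a
# Conv(n -> n//2) followed by PixelUnshuffle(2), giving 2*n channels at half resolution, while
# Upsample(n) is a Conv(n -> 2*n) followed by PixelShuffle(2), giving n//2 channels at double
# resolution. This is why up2_1 needs no 1x1 reduction: upsampling the 2*dim level-2 output
# already yields dim channels, and concatenating the dim-channel level-1 skip lands exactly on
# the 2*dim width expected by decoder_level1.
import torch
from torch import nn

n_feat = 96   # e.g. 2*dim with dim=48
down = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, 3, 1, 1, bias=False), nn.PixelUnshuffle(2))
up = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, 3, 1, 1, bias=False), nn.PixelShuffle(2))
x = torch.randn(1, n_feat, 60, 60)
print(down(x).shape)   # torch.Size([1, 192, 30, 30])
print(up(x).shape)     # torch.Size([1, 48, 120, 120])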
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + + + return out_dec_level1 + + + + +if __name__ == '__main__': + model = Restormer() + # print(model) + + A = torch.randn((1, 1, 240, 240)) + x = model(A) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/swinIR.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/swinIR.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfc9c175c02531ac80a05ba4abfc0776f752dd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/swinIR.py @@ -0,0 +1,874 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * 
(self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
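+
+        Note: blocks alternate between regular windows (shift_size=0) and shifted
+        windows (shift_size=window_size // 2), so consecutive blocks form W-MSA / SW-MSA pairs.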
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
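+            # The '3conv' variant is a bottleneck: a 3x3 conv down to embed_dim // 4,
+            # a 1x1 conv, then a 3x3 conv back to embed_dim, which costs fewer
+            # parameters and less memory than a single full-width 3x3 convolution.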
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 4, 4, 6], + embed_dim=32, num_heads=[6, 6, 6, 6], mlp_ratio=4, upsampler='pixelshuffledirect') + + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/swin_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..877bdfbffc91012e4ac31f41b838ebb116f4a825 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/swin_transformer.py @@ -0,0 +1,878 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
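+# This copy under networks_time/compare_models defaults to single-channel inputs
+# (in_chans=1) and upscale=1, matching the MRI reconstruction setting.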
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
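+
+        Note: img_size and patch_size only configure the PatchEmbed / PatchUnEmbed helpers,
+        which here (in_chans=0, no normalization) simply reshape between (B, H*W, C) token
+        sequences and (B, C, H, W) feature maps.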
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=1, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
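+            # conv_after_body aggregates the deep Swin features before they are added back
+            # to the shallow feature map (global residual connection) in forward(); the
+            # bottleneck form below reduces parameters relative to a plain 3x3 convolution.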
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + # print("shallow feature:", x.shape) + x = self.conv_after_body(self.forward_features(x)) + x + # print("deep feature:", x.shape) + x = self.upsample(x) + # print("after upsample:", x.shape) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x 
= self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + # return SwinIR(upscale=1, img_size=(240, 240), + # window_size=8, img_range=1., depths=[6, 6, 6, 6], + # embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + +if __name__ == '__main__': + height = 240 + width = 240 + window_size = 8 + model = SwinIR(upscale=1, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + x = torch.randn((1, 1, height, width)) + print("input:", x.shape) + x = model(x) + print("output:", x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet.zip b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet.zip new file mode 100644 index 0000000000000000000000000000000000000000..e7b71c1287551fd6119efd7c62da2196307998f5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:369de2a28036532c446409ee1f7b953d5763641ffd452675613e1bb9bba89eae +size 19354 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet/trans_unet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet/trans_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa61dbadfcaee9b26594307adc90cdd1d2a84e3e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet/trans_unet.py @@ -0,0 +1,489 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math + +from os.path import join as pjoin + +import torch +import torch.nn as nn +import numpy as np + +from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair +from scipy import ndimage +from . 
import vit_seg_configs as configs +# import .vit_seg_configs as configs +from .vit_seg_modeling_resnet_skip import ResNetV2 + + +logger = logging.getLogger(__name__) + + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class 
Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + # print("image size:", img_size, "grid size:", grid_size) + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + # print("patch_size_real", patch_size_real) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + # print("input x:", x.shape) ### (4, 3, 240, 240) + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + # print("output of the hybrid model:", x.shape) ### (4, 1024, 15, 15) + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = 
np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + # print(self.encoder) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) ### "features" store the features from different layers. + # print("embedding_output:", embedding_output.shape) ## (B, n_patch, hidden) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + # print("encoded feature:", encoded.shape) + return encoded, attn_weights, features + + # return embedding_output, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class ReconstructionHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + 
self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + # print(self.blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): + super(VisionTransformer, self).__init__() + self.num_classes = num_classes + self.zero_head = zero_head + self.classifier = config.classifier + self.transformer = Transformer(config, img_size, vis) + self.decoder = DecoderCup(config) + self.segmentation_head = ReconstructionHead( + in_channels=config['decoder_channels'][-1], + out_channels=config['n_classes'], + kernel_size=3, + ) + self.activation = nn.Tanh() + self.config = config + + def forward(self, x): + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + ### hybrid CNN and transformer to extract features from the input images. + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + + # ### extract features with CNN only. 
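+        # (Descriptive note: the commented-out call below appears to pair with the
+        #  alternative `return embedding_output, features` left commented in
+        #  Transformer.forward, i.e. using the CNN features without the attention encoder.
+        #  `features` holds the multi-scale ResNet skip maps that the decoder fuses with
+        #  the upsampled transformer output.)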
+ # x, features = self.transformer(x) # (B, n_patch, hidden) + + x = self.decoder(x, features) + # logits = self.segmentation_head(x) + logits = self.activation(self.segmentation_head(x)) + # print("logits:", logits.shape) + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + + + +def build_model(args): + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + print(net) + # net.load_from(weights=np.load(config_vit.pretrained_path)) + return net + + +if __name__ == "__main__": + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + net.load_from(weights=np.load(config_vit.pretrained_path)) + + input_data = torch.rand(4, 1, 240, 240).cuda() + print("input:", input_data.shape) + output = net(input_data) + print("output:", output.shape) diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet/vit_seg_configs.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet/vit_seg_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bed7d3346e30328a38ac6fef1621f788d905df1e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet/vit_seg_configs.py @@ -0,0 +1,132 @@ +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 4 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + # config.patches.grid = (16, 16) + config.patches.grid = (15, 15) ### for the input image with size (240, 240) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'recon' + config.pretrained_path = '/home/xiaohan/workspace/MSL_MRI/code/pretrained_model/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 1 + config.n_skip = 3 + config.activation = 'tanh' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. 
customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae80ef7d96c46cb91bda849901aef34008e62ed --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py @@ -0,0 +1,160 @@ +import math + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. 
+ self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - 
x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/transformer_modules.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/transformer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..900042884305619e39ad9b792b3b1936dfbd37b3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/transformer_modules.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) 
* self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
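+        # (Since 2*Phi(z) - 1 = erf(z/sqrt(2)), the erfinv_ call below maps these uniform
+        #  samples to z/sqrt(2); the subsequent mul_(std*sqrt(2)) and add_(mean) then place
+        #  the values inside [a, b] before the final clamp.)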
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=4, + num_heads=4, + mlp_ratio=3) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dd935cd34a121d8079a8f5184d4ea29e15c8da --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unet.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F + + +class Unet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. 
In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = image + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + # print("unet feature range:", output.max(), output.min()) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unet_restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f78aa5f192957b2706c6f691dfc66c427b65ae --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unet_restormer.py @@ -0,0 +1,234 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + # print("Unet feature range:", output.max(), output.min()) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unet_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..982fe23cec321de6b636e04c51b62037054ea432 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unet_transformer.py @@ -0,0 +1,346 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +对于Unet中的high-level feature, 用Transformer进行特征交互,然后和Transformer之前的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def 
get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Unet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output = image + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack, output = self.encoder(output) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output) # [4, 225, 256], 225 is the number of patches. + patch_embed_input = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1])), int(np.sqrt(feature_output.shape[1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[-1], h, w) # [4,512*2,24,24] + + output = torch.cat((output, feature_output), 1) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output = self.decoder(output, stack) + + return output + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. 
+ + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Transformer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unimodal_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unimodal_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c99214e10bee2f7a1be49f8560799ffbc14ed0a5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/compare_models/unimodal_transformer.py @@ -0,0 +1,112 @@ +import torch + +from torch import nn +from .transformer_modules import build_transformer +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(): + return ReconstructionHead(input_dim=1, hidden_dim=12) + + + +class CMMT(nn.Module): + + def __init__(self, args): + super(CMMT, self).__init__() + + self.head = build_head() + + HEAD_HIDDEN_DIM = 12 + PATCH_SIZE = 16 + INPUT_SIZE = 240 + OUTPUT_DIM = 1 + + patch_dim = HEAD_HIDDEN_DIM* PATCH_SIZE ** 2 + num_patches = (INPUT_SIZE // PATCH_SIZE) ** 2 + + self.transformer = build_transformer(patch_dim) + + self.patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=PATCH_SIZE, p2=PATCH_SIZE), + ) + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, patch_dim)) + + + self.p1 = PATCH_SIZE + self.p2 = PATCH_SIZE + + self.tail = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x): + + b,_ , h, w = x.shape + + x = self.head(x) + + x= self.patch_embbeding(x) + + x += self.pos_embedding + + x = self.transformer(x) # b HW p1p2c + + c = 
int(x.shape[2]/(self.p1*self.p2)) + H = int(h/self.p1) + W = int(w/self.p2) + + x = x.reshape(b, H, W, self.p1, self.p2, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w,) + x = self.tail(x) + + return x + + +def build_model(args): + return CMMT(args) + + + + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/modules.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/modules.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
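+# Concretely, the module below normalizes each sample per channel using the mean and variance computed over the H*W spatial positions, and then applies an optional learned scale and shift when `affine` is enabled.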
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/mynet.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/mynet.py new file mode 100644 index 0000000000000000000000000000000000000000..042c6b2ca805b7b3772f8f244edeffa812841296 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/networks_time/mynet.py @@ -0,0 +1,467 @@ +import torch, math +from torch import nn +from networks_time import common_freq as common + + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + + +class Block_Sequential(nn.Module): + def __init__(self, block1, block2): + super(Block_Sequential, self).__init__() + self.block1 = block1 + self.block2 = block2 + + def forward(self, x, t=None): + x = self.block1(x) + x = self.block2(x, t) + return x + + +class DiffTwoBranch(nn.Module): + def __init__(self, args): + super(DiffTwoBranch, self).__init__() + + num_group = 4 + num_every_group = args.base_num_every_group + self.args = args + + self.ch = args.num_channels + self.temb_ch = args.num_channels * 4 + args.temb_channels = self.temb_ch + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(args.num_channels, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ]) + + self.init_T2_frq_branch(args) + self.init_T2_spa_branch(args, num_every_group) + self.init_T2_fre_spa_fusion(args) + + self.init_T1_frq_branch(args) + self.init_T1_spa_branch(args, num_every_group) + + self.init_modality_fre_fusion(args) + self.init_modality_spa_fusion(args) + + + def init_T2_frq_branch(self, args): + ### T2frequency branch + self.head_fre = common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act) + + self.down1_fre = Block_Sequential(*[common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args)]) + + self.down1_fre_mo = common.FreBlock9(args.num_features, args) + + modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre = Block_Sequential(*modules_down2_fre) + + self.down2_fre_mo = common.FreBlock9(args.num_features, args) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre = Block_Sequential(*modules_down3_fre) + self.down3_fre_mo = common.FreBlock9(args.num_features, args) + + self.neck_fre = common.FreBlock9(args.num_features, args) + + self.neck_fre_mo = common.FreBlock9(args.num_features, args) + + modules_up1_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up1_fre = Block_Sequential(*modules_up1_fre) + self.up1_fre_mo = common.FreBlock9(args.num_features, args) + + modules_up2_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up2_fre = Block_Sequential(*modules_up2_fre) + self.up2_fre_mo = common.FreBlock9(args.num_features, args) + + modules_up3_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up3_fre = Block_Sequential(*modules_up3_fre) + self.up3_fre_mo = common.FreBlock9(args.num_features, args) + + # define tail module + self.tail_fre = common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act) + + def init_T2_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [] + self.head = common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, 
padding=1, act=args.act) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.down1 = Block_Sequential(*modules_down1) + + + self.down1_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.down2 = Block_Sequential(*modules_down2) + + self.down2_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.down3 = Block_Sequential(*modules_down3) + self.down3_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + self.neck = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + self.neck_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + modules_up1 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.up1 = Block_Sequential(*modules_up1) + + self.up1_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + modules_up2 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.up2 = Block_Sequential(*modules_up2) + self.up2_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + + modules_up3 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + ] + self.up3 = Block_Sequential(*modules_up3) + self.up3_mo = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None, temp_channel=args.temb_channels) + + # define tail module + self.tail = common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act) + + def init_T2_fre_spa_fusion(self, args): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(args.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, args): + ### T2frequency branch + self.head_fre_T1 = common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre_T1 = Block_Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = 
common.FreBlock9(args.num_features, args) + + modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre_T1 = Block_Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = common.FreBlock9(args.num_features, args) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre_T1 = Block_Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = common.FreBlock9(args.num_features, args) + + + self.neck_fre_T1 = common.FreBlock9(args.num_features, args) + self.neck_fre_mo_T1 = common.FreBlock9(args.num_features, args) + + def init_T1_spa_branch(self, args, num_every_group): + ### spatial branch + self.head_T1 = common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = Block_Sequential(*modules_down1) + + + self.down1_mo_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = Block_Sequential(*modules_down2) + + self.down2_mo_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down3_T1 = Block_Sequential(*modules_down3) + self.down3_mo_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + self.neck_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + self.neck_mo_T1 = common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + + + def init_modality_fre_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + + def forward(self, main, aux, t=None): + + # self.temb_proj = torch.nn.Linear(temb_channels, + # out_channels) + # h = self.norm1(h) + # h = nonlinearity(h) + # h = self.conv1(h) + # + # h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + + temb = None + + + if not isinstance(t, type(None)): + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + 
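+ # The T1 frequency features computed above (head, down1-3, neck) are fused stage-by-stage into the T2 frequency branch below through self.conv_fuse_fre[0..4]; each fusion block receives the T1 feature first and the corresponding T2 feature second.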
+ #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse, temb)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre, temb) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse, temb) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre, temb) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse, temb) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre, temb) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse, temb) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre, temb) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo, temb) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre, temb) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo, temb) # 64 + up2_fre_mo = self.up2_fre_mo(up2_fre, temb) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo, temb) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre, temb) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 temb + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse, temb) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse, temb) + down1_fuse_mo = self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse, temb) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse, temb) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse, temb) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse, temb) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse, temb) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse, temb) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse, temb) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse, temb) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo, temb) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse, temb) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = 
self.up3(up2_fuse_mo, temb) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse, temb) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + + # import matplotlib.pyplot as plt + # plt.axis('off') + # plt.imshow((255*up3_fre_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fre_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_fuse_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fuse_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + # breakpoint() + + res = self.tail(up3_fuse_mo) + + return {'img_out': res + main, 'img_fre': res_fre + main} + +def make_model(args): + return DiffTwoBranch(args) + + + +if __name__ == "__main__": + # Test the model + from utils.option import args + + network = DiffTwoBranch(args) + # network = build_model_from_name(args).cuda() + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + # network = nn.SyncBatchNorm.convert_sync_batchnorm(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + # Test the model + data = torch.randn(1, 1, 128, 128)#.cuda + out = network(data, data) + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/results.md b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/results.md new file mode 100644 index 0000000000000000000000000000000000000000..1be41330341c639beeddcc0853efd6de001742a5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/results.md @@ -0,0 +1,110 @@ + + +# Brats 8X (Ready) + +results/brats_8x/ + + + +python test_brats.py --root_path $root_path_8x \ + --gpu 1 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_BraTS_8x --phase test --use_time_model --use_kspace \ + --test_sample Ksample + + + +Results +------------------------------------ +Ksample NMSE: 0.4580 ± 0.1765 +Ksample PSNR: 39.8750 ± 1.8064 +Ksample SSIM: 0.9810 ± 0.0053 +----------------------------------- + + + + + +# Brats 4X +results/brats_4x + + +Results +------------------------------------ +Ksample NMSE: 0.3968 ± 0.1552 +Ksample PSNR: 40.5047 ± 1.8174 +Ksample SSIM: 0.9834 ± 0.0044 +------------------------------------ + + + + + + + +# Brats 12X (Ready) + +results/brats_12x + +------------------------------------ +Ksample NMSE: 0.5821 ± 0.2252 +Ksample PSNR: 38.8360 ± 1.8016 +Ksample SSIM: 0.9778 ± 0.0063 +------------------------------------ + + + + + +# M4raw 4X t5 +results/m4raw_4x + + +------------------------------------ +Ksample NMSE: 2.4857 ± 0.1376 +Ksample PSNR: 29.9710 ± 0.4033 +Ksample SSIM: 0.7991 ± 0.0115 +------------------------------------ + + + +# FastMRI 4X t5 (Ready) + +results/fastmri_4x/ + +------------------------------------ +NMSE: 2.5041 ± 0.6921 +PSNR: 30.8036 ± 1.8245 +SSIM: 0.7506 ± 0.0418 +------------------------------------ + + +# FastMRI 8X t5 (Ready) +results/fastmri_8x/ + +------------------------------------ +NMSE: 3.6736 ± 0.9377 +PSNR: 29.1211 ± 1.7199 +SSIM: 0.6225 ± 0.0654 +------------------------------------ + +% Save Path: model/FSMNet_fastmri_8x_t5_kspace_time//result_case/ + + + + + + +# FastMRI 12X t5 (Ready) + + +results/fastmri_12x/ + + +------------------------------------ 
+NMSE: 4.2541 ± 1.0271 +PSNR: 28.4743 ± 1.6616 +SSIM: 0.5642 ± 0.0772 +------------------------------------ + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/test_brats.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/test_brats.py new file mode 100644 index 0000000000000000000000000000000000000000..2a2b403925e7c8a1237f490a3471222d6061eee3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/test_brats.py @@ -0,0 +1,349 @@ +import os +import sys +from tqdm import tqdm +import argparse +import logging +from skimage import io + +from torchvision import transforms +from torch.utils.data import DataLoader +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import ToTensor +from networks.mynet import TwoBranch +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from utils.option import args + + +def normalise_mse(gt, pred): + """Compute Normalized Mean Squared Error (NMSE)""" + return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 + + + +parser = argparse.ArgumentParser() +parser.add_argument('--root_path', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/selected_images/') +parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') +parser.add_argument('--low_field_SNR', type=int, default=15, help='SNR of the simulated low-field image') +parser.add_argument('--phase', type=str, default='test', help='Name of phase') +parser.add_argument('--gpu', type=str, default='0', help='GPU to use') +parser.add_argument('--exp', type=str, default='msl_model', help='model_name') +parser.add_argument('--seed', type=int, default=1337, help='random seed') +parser.add_argument('--base_lr', type=float, default=0.0002, help='maximum epoch numaber to train') + +parser.add_argument('--model_name', type=str, default='unet_single', help='model_name') +parser.add_argument('--relation_consistency', type=str, default='False', help='regularize the consistency of feature relation') +parser.add_argument('--norm', type=str, default='False', help='Norm Layer between UNet and Transformer') +parser.add_argument('--input_normalize', type=str, default='mean_std', help='choose from [min_max, mean_std, divide]') +parser.add_argument('--test_sample', default="Ksample", help="Ksample | ColdDiffusion | DDPM") + +# args = parser.parse_args() + +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + +from utils.utils import * +from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial +from networks_time.mynet import DiffTwoBranch + +DEBUG = args.DEBUG +use_time_model = args.use_time_model +use_kspace = args.use_kspace +use_t2_in = True + +num_timesteps = args.num_timesteps +image_size = 240 + + +if args.MRIDOWN == "4X": + accelerate_mask = np.load("./dataloaders/example_mask/brats_4X_mask.npy") + accelerate_mask = torch.from_numpy(accelerate_mask).unsqueeze(0).clone().float() + print("accelerate_mask shape =", accelerate_mask.shape) + +else: + accelerate_mask = None + +k_file = f"./dataloaders/example_mask/brats_{args.ACCELERATIONS[0]}_kspace_mask.npy" + +if os.path.exists(k_file): + kspace_masks = np.load(k_file) + kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + print("Use existing kfile:", k_file) +else: + # Output a list of k-space kernels + kspace_masks = get_ksu_kernel(num_timesteps, image_size, 
+ ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0], + accelerate_mask=accelerate_mask + ) + kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + + + +test_sample = args.test_sample # Ksample | ColdDiffusion | DDPM + +if __name__ == "__main__": + ## make logger file + if use_kspace: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_kspace/' + + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not isinstance(args.test_tag, type(None)): + snapshot_path = snapshot_path.rstrip("/") + f'_{args.test_tag}/' + + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + + + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_test = MyDataset(split='test', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([ToTensor()]), + base_dir=test_data_path, input_normalize = args.input_normalize) + + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=2, pin_memory=True) + + + if args.phase == 'test': + + save_mode_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + print('load weights from ' + save_mode_path) + + + checkpoint = torch.load(save_mode_path) + network.load_state_dict(checkpoint['network']) + network.eval() + cnt = 0 + save_path = snapshot_path + '/result_case/' + feature_save_path = snapshot_path + '/feature_visualization/' + if not os.path.exists(save_path): + os.makedirs(save_path) + if not os.path.exists(feature_save_path): + os.makedirs(feature_save_path) + + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all, t2_NMSE_all = [], [], [], [] + + for (sampled_batch, sample_stats) in tqdm(testloader, ncols=70): + cnt += 1 + + print('processing ' + str(cnt) + ' image') + + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + + t1_out, t2_out = None, None + + if use_kspace: + b = t2.shape[0] + t = torch.randint(num_timesteps - 1, num_timesteps, (b,), device=device).long() # t-1 + + mask = kspace_masks[t] + + + target_fft, _ = apply_tofre(t2.clone(), mask) + fft, mask = apply_tofre(t2_in.clone(), mask) + + fft = target_fft * mask + fft * (1 - mask) # Seems too easy + t2_in = apply_to_spatial(fft) + + + while t >= 0: + # for t in range(t, -1, -1): + if use_time_model: + outputs = network(t2_in, t1_in, t)['img_out'] # t1_in? 
+ else: + outputs = network(t2_in, t1_in)['img_out'] + + if t == 0: + mask = kspace_masks[0] # last one + t2_in = outputs + if use_time_model: + t2_out_2 = network(t2_in, t1_in, t)['img_out'] + else: + t2_out_2 = network(t2_in, t1_in)['img_out'] + else: + + if test_sample == "Ksample": # Ksample | ColdDiffusion | DDPM + + k_full = kspace_masks[-1] + t2_in_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # self.get_kspace_kernels(t - 1).cuda() # last one + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + + t2_in_fre = t2_in_fre * (1 - k_residual) + recon_sample_fre * k_residual + + outputs = apply_to_spatial(t2_in_fre) + t2_in = outputs + + + elif test_sample == "KsampleAR": # Ksample | ColdDiffusion | DDPM + + k_full = kspace_masks[-1] + t2_in_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # self.get_kspace_kernels(t - 1).cuda() # last one + k_residual = kt_sub_1 # - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + # fre_amend = recon_sample_fre * k_residual + + t2_in_fre = t2_in_fre * (1 - k_residual) + recon_sample_fre * k_residual + + outputs = apply_to_spatial(t2_in_fre) + t2_in = outputs + + + elif test_sample == "ColdDiffusion": + k_full = kspace_masks[-1] + # t2_in_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # self.get_kspace_kernels(t - 1).cuda() # last one + + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + + x_t_hat_fre = recon_sample_fre * kt + x_t_sub_1_hat_fre = recon_sample_fre * kt_sub_1 + + x_t_hat = apply_to_spatial(x_t_hat_fre) + x_t_sub_1_hat = apply_to_spatial(x_t_sub_1_hat_fre) + + outputs = t2_in - x_t_hat + x_t_sub_1_hat + + t2_in = outputs + + elif test_sample == "DDPM": + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + + recon_sample_fre, kt_sub_1 = apply_tofre(outputs, kt_sub_1) + fre_new = recon_sample_fre * kt_sub_1 + + outputs = apply_to_spatial(fre_new) + t2_in = outputs + + t = t - 1 + t2_out = outputs + + else: + t2_out = network(t2_in, t1_in)['img_out'] + t2_out_2 = network(t2_in, t1_in)['img_out'] + + t1_mean = sample_stats['t1_mean'].data.cpu().numpy()[0] + t1_std = sample_stats['t1_std'].data.cpu().numpy()[0] + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + + + + # k_full = kspace_masks[-1] # full true mask + # t2_in_fre, k_full = apply_tofre(t2_in, k_full) + # outputs = apply_to_spatial(t2_in_fre) + + + + + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + + + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_img = 
(np.clip(t2_out.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_2_img = (np.clip(t2_out_2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + + + io.imsave(save_path + str(cnt) + '_t1.png', bright(t1_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2.png', bright(t2_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_original.png', t2_img) + io.imsave(save_path + str(cnt) + '_t2_in.png', bright(t2_in_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_out_original.png', t2_out_img) + io.imsave(save_path + str(cnt) + '_t2_out.png', bright(t2_out_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_out2.png', bright(t2_out_2_img,0,0.8)) + + # ------------------------------------ + # NMSE: 5.4534 ± 1.5515 + # PSNR: 39.2132 ± 1.6888 + # SSIM: 0.9792 ± 0.0054 + # ------------------------------------ + # Save Path: model/FSMNet_BraTS_8x_kspace//result_case/ + + if t2_out is not None: + t2_out_img[t2_out_img < 0.0] = 0.0 + t2_img[t2_img < 0.0] = 0.0 + + MSE = mean_squared_error(t2_img, t2_out_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_out_img) + SSIM = structural_similarity(t2_img, t2_out_img) + nmse = normalise_mse(t2_img/255, t2_out_img/255) + + t2_MSE_all.append(MSE) + t2_PSNR_all.append(PSNR) + t2_SSIM_all.append(SSIM) + t2_NMSE_all.append(nmse) + + print("[t2 MRI] MSE:", MSE, "PSNR:", PSNR, "SSIM:", SSIM, "NMSE:", nmse) + + + print("===> Evaluate Metric <===") + print("Results") + print("-" * 36) + print(f"{test_sample} NMSE: {np.array(t2_NMSE_all).mean() * 100:.4f} ± {np.array(t2_NMSE_all).std() * 100 :.4f}") + # print(f"MSE: {np.array(t2_MSE_all).mean():.4f} ± {np.array(t2_MSE_all).std():.4f}") + print(f"{test_sample} PSNR: {np.array(t2_PSNR_all).mean():.4f} ± {np.array(t2_PSNR_all).std():.4f}") + print(f"{test_sample} SSIM: {np.array(t2_SSIM_all).mean():.4f} ± {np.array(t2_SSIM_all).std():.4f}") + print("-" * 36) + print(f"Save Path: {save_path}") + + + + # print("[T2 MRI:] average MSE:", np.array(t2_MSE_all).mean(), "average PSNR:", np.array(t2_PSNR_all).mean(), "average SSIM:", np.array(t2_SSIM_all).mean()) + # print("[T2 MRI:] average MSE:", np.array(t2_MSE_all).std(), "average PSNR:", np.array(t2_PSNR_all).std(), "average SSIM:", np.array(t2_SSIM_all).std()) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/test_fastmri.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/test_fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..3945b6e115f2219363a5bd30ae713e4e77c04abd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/test_fastmri.py @@ -0,0 +1,267 @@ +import os +import sys +import logging +from skimage import io +from skimage import img_as_ubyte + +from torch.utils.data import DataLoader +from networks.mynet import TwoBranch + +from utils.option import args +from tqdm import tqdm +from utils.metric import nmse, psnr, ssim +from collections import defaultdict +from networks_time.mynet import DiffTwoBranch + +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + +def normalize_output(out_img): + out_img = (out_img - out_img.min())/(out_img.max() - out_img.min() + 1e-8) + return out_img + +from frequency_diffusion.degradation.k_degradation import apply_tofre, apply_to_spatial +from utils.utils import * + +DEBUG = False +use_kspace = args.use_kspace +use_time_model = 
args.use_time_model +num_timesteps = args.num_timesteps +image_size = args.image_size + +from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial + +kfile = f"./dataloaders/example_mask/kspace_{args.ACCELERATIONS[0]}_mask.npy" +print("using kfile:", kfile) + + +kspace_masks = np.load(kfile) +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0] + ) + +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + + + +print("kspace_masks shape: ", kspace_masks.shape) + +@torch.no_grad() +def evaluate(model, data_loader, device, save_path): + os.makedirs(save_path, exist_ok=True) + + model.eval() + nmse_meter, psnr_meter, ssim_meter = [], [], [] + direct_nmse, direct_psnr, direct_ssim = [], [], [] + output_dic = defaultdict(dict) + target_dic = defaultdict(dict) + input_dic = defaultdict(dict) + direct_recon_dic = defaultdict(dict) + + flag=0 + last_name='no' + + print("len of data_loader: ", len(data_loader)) + + for id, data in tqdm(enumerate(data_loader)): + pd, pdfs, _ = data + name = os.path.basename(pdfs[4][0]).split('.')[0] + + target = pdfs[1].to(device) + + mean, std = pdfs[2], pdfs[3] + + fname = pdfs[4] + slice_num = pdfs[5] + + mean = mean.unsqueeze(1).unsqueeze(2).to(device) + std = std.unsqueeze(1).unsqueeze(2).to(device) + + pd_img = pd[1].unsqueeze(1).to(device) + pdfs_img = pdfs[0].unsqueeze(1).to(device) + + pdfs_img_origin = pdfs_img.clone() + + # Degradation + if use_kspace: + b = pdfs_img.shape[0] + # t = torch.randint(num_timesteps - 1, num_timesteps, (b,), device=device).long() # t-1 + t = num_timesteps - 1 + t = torch.tensor([t], device=device).long() # t-1 + + mask = kspace_masks[t] + + fft, mask = apply_tofre(target.clone(), mask) + fft = fft * mask + 0.0 + pdfs_img = apply_to_spatial(fft) + + while t >= 0: + if use_time_model: + outputs = network(pdfs_img, pd_img, t)['img_out'] + else: + outputs = network(pdfs_img, pd_img)['img_out'] + + + if t == num_timesteps - 1: + direct_recon = outputs + + if t == 0: + mask = kspace_masks[0] # last one + pdfs_img = outputs + + else: + k_full = kspace_masks[-1] + faded_recon_sample_fre, k_full = apply_tofre(pdfs_img, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # self.get_kspace_kernels(t - 1).cuda() # last one + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + + faded_recon_sample_fre = faded_recon_sample_fre * (1 - k_residual) + recon_sample_fre * k_residual + + outputs = apply_to_spatial(faded_recon_sample_fre) + pdfs_img = outputs + + + t = t-1 + + else: + outputs = network(pdfs_img, pd_img)['img_out'] + + outputs = outputs.squeeze(1) + direct_recon = direct_recon.squeeze(1) + + outputs_save = outputs[0].cpu().clone().numpy()/6.0 + outputs_save = np.clip(outputs_save, a_min=-1, a_max=1) + target_save = target[0].cpu().clone().numpy()/6.0 + in_save = pdfs_img_origin[0][0].cpu().clone().numpy()/6.0 + + # Not sure if it was correct to convert to ubyte + outputs_save = img_as_ubyte(outputs_save) + target_save = img_as_ubyte(target_save) + in_save = img_as_ubyte(in_save) + + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '.png', target_save) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_in.png', in_save) + io.imsave(save_path + str(name) + '_' + 
str(slice_num[0].cpu().numpy()) + '_out.png', outputs_save) + + outputs = outputs * std + mean + target = target * std + mean + inputs = pdfs_img_origin.squeeze(1) * std + mean + direct_recon = direct_recon * std + mean + + output_dic[fname[0]][slice_num[0]] = outputs[0] + target_dic[fname[0]][slice_num[0]] = target[0] + input_dic[fname[0]][slice_num[0]] = inputs[0] + direct_recon_dic[fname[0]][slice_num[0]] = direct_recon[0] + + # print("target/outputs shape: ", target.shape, outputs.shape) + our_nmse = nmse(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_psnr = psnr(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_ssim = ssim(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + + # print('name:{}, slice:{}, nmse:{}, psnr:{}, ssim:{}'.format(name, slice_num[0], our_nmse, our_psnr, our_ssim)) + + + for name in output_dic.keys(): + f_output = torch.stack([v for _, v in output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.append(our_nmse) + psnr_meter.append(our_psnr) + ssim_meter.append(our_ssim) + + direct_nmse.append(nmse(f_target.cpu().numpy(), torch.stack([v for _, v in direct_recon_dic[name].items()]).cpu().numpy())) + direct_psnr.append(psnr(f_target.cpu().numpy(), torch.stack([v for _, v in direct_recon_dic[name].items()]).cpu().numpy())) + direct_ssim.append(ssim(f_target.cpu().numpy(), torch.stack([v for _, v in direct_recon_dic[name].items()]).cpu().numpy())) + + nmse_meter_score = np.array(nmse_meter) + psnr_meter_score = np.array(psnr_meter) + ssim_meter_score = np.array(ssim_meter) + + direct_nmse_score = np.array(direct_nmse) + direct_psnr_score = np.array(direct_psnr) + direct_ssim_score = np.array(direct_ssim) + + print("===> Evaluate Metric <===") + print("Direct Results") + print("-" * 36) + print(f"NMSE: {np.mean(direct_nmse_score) * 100:.4f} ± {np.std(direct_nmse_score) * 100:.4f}") + print(f"PSNR: {np.mean(direct_psnr_score):.4f} ± {np.std(direct_psnr_score):.4f}") + print(f"SSIM: {np.mean(direct_ssim_score):.4f} ± {np.std(direct_ssim_score):.4f}") + print("-" * 36) + + print("===> Evaluate Metric <===") + print("Results") + print("-" * 36) + print(f"NMSE: {np.mean(nmse_meter_score) * 100:.4f} ± {np.std(nmse_meter_score) * 100:.4f}") + print(f"PSNR: {np.mean(psnr_meter_score):.4f} ± {np.std(psnr_meter_score):.4f}") + print(f"SSIM: {np.mean(ssim_meter_score):.4f} ± {np.std(ssim_meter_score):.4f}") + print("-" * 36) + print(f"Save Path: {save_path}") + + model.train() + return {'NMSE': np.mean(nmse_meter_score), 'PSNR': np.mean(psnr_meter_score), 'SSIM': np.mean(ssim_meter_score)} + + +from dataloaders.fastmri import build_dataset +if __name__ == "__main__": + ## make logger file + if use_kspace: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_kspace/' + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + + device = torch.device('cuda') + 
network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + db_test = build_dataset(args, mode='val', use_kspace=use_kspace) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'test': + + save_mode_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + # save_mode_path = os.path.join(snapshot_path, 'iter_100000.pth') + print('load weights from ' + save_mode_path) + + checkpoint = torch.load(save_mode_path) + + weights_dict = {} + for k, v in checkpoint['network'].items(): + new_k = k.replace('module.', '') if 'module' in k else k + weights_dict[new_k] = v + + network.load_state_dict(weights_dict) + network.eval() + + eval_result = evaluate(network, testloader, device, save_path = snapshot_path + '/result_case/') + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/test_m4raw.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/test_m4raw.py new file mode 100644 index 0000000000000000000000000000000000000000..f53fea8715c3c9fe210ceaff5efdeb4b90d89f9e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/test_m4raw.py @@ -0,0 +1,353 @@ +import os +import sys +import logging +from skimage import io +from skimage import img_as_ubyte + +from torch.utils.data import DataLoader +from networks.mynet import TwoBranch + +from utils.option import args +from tqdm import tqdm +from utils.metric import nmse, psnr, ssim +from collections import defaultdict +from networks_time.mynet import DiffTwoBranch + +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +use_new_dataloader = True + + +# Results + + +def normalize_output(out_img): + out_img = (out_img - out_img.min()) / (out_img.max() - out_img.min() + 1e-8) + return out_img + + +from frequency_diffusion.degradation.k_degradation import apply_tofre, apply_to_spatial, apply_ksu_kernel +from utils.utils import * + + +num_timesteps = args.num_timesteps +image_size = args.image_size +distortion_sigma = 10 / 255 +use_kspace = args.use_kspace +use_time_model = args.use_time_model +DEBUG = args.DEBUG + + +kfile=f"./dataloaders/example_mask/m4raw_{args.ACCELERATIONS[0]}_mask.npy" +print("kfile = ", kfile) + + +kspace_masks = np.load(kfile) +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +test_sample = args.test_sample # Ksample | ColdDiffusion | DDPM +frequency_distortion = True + +@torch.no_grad() +def evaluate(model, data_loader, device, save_path): + os.makedirs(save_path, exist_ok=True) + + model.eval() + nmse_meter = [] + psnr_meter = [] + ssim_meter = [] + nmse_meter_all = [] + psnr_meter_all = [] + ssim_meter_all = [] + output_dic = {} # defaultdict(dict) + target_dic = {} # efaultdict(dict) + input_dic = {} # defaultdict(dict) + + flag = 0 + last_name = 'no' + + print("len of data_loader: ", len(data_loader)) + + for id, sampled_batch in tqdm(enumerate(data_loader)): + t1_img, t1_in = sampled_batch['t1'], sampled_batch['t1_in'] + t2_img, t2_in = sampled_batch['t2'], sampled_batch['t2_in'] + + t1_img = t1_img.to(device) + t1_in = t1_in.to(device) + t2_img = t2_img.to(device) + t2_in = t2_in.to(device) + + mean, std = sampled_batch['t2_mean'], sampled_batch['t2_std'] + + name = sampled_batch['fname'] + fname = [name] + slice_num = sampled_batch['slice'] + + mean = mean.unsqueeze(1).unsqueeze(2).to(device) + std = std.unsqueeze(1).unsqueeze(2).to(device) + + t2_in_origin = t2_in.clone() + + # Degradation + if use_kspace: + b = 
1 + t = torch.randint(num_timesteps - 1, num_timesteps, (b,), device=device).long() # t-1 + mask = kspace_masks[t] + + fft, mask = apply_tofre(t2_in.clone(), mask) # t2_img + fft = fft * mask + 0.0 + t2_in = apply_to_spatial(fft) + t2_in_origin = t2_in.clone() + + while t >= 0: + + # outputs = model(t2_in, t1_img)['img_out'] + if use_time_model: + outputs = model(t2_in, t1_img, t)['img_out'] + else: + outputs = model(t2_in, t1_img)['img_out'] + + if t == 0: + mask = kspace_masks[0] # last one + t2_in = outputs + + else: + if test_sample == "Ksample": # Ksample | ColdDiffusion | DDPM + k_full = kspace_masks[-1] + faded_recon_sample_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] + kt = kspace_masks[t] + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + faded_recon_sample_fre = recon_sample_fre * k_residual + faded_recon_sample_fre * (1 - k_residual) + # faded_recon_sample_fre = faded_recon_sample_fre + fre_amend + + outputs = apply_to_spatial(faded_recon_sample_fre) + t2_in = outputs + + + elif test_sample == "ColdDiffusion": + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] + kt = kspace_masks[t] + + x_t_hat = apply_ksu_kernel(outputs, kt) + x_t_sub_1_hat = apply_ksu_kernel(outputs, kt_sub_1) + + outputs = t2_in - x_t_hat + x_t_sub_1_hat + + t2_in = outputs + + elif test_sample == "DDPM": + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + outputs = apply_ksu_kernel(kt_sub_1, kt_sub_1) + t2_in = outputs + + + t = t - 1 + + else: + outputs = model(t2_in, t1_img)['img_out'] + + # print("outputs shape: ", outputs.shape, outputs.min(), outputs.max()) + # print("t2_img shape: ", t2_img.shape, t2_img.min(), t2_img.max()) + + target = t2_img.clone().squeeze(1) * std + mean + inputs = t2_in_origin.clone().squeeze(1) * std + mean + outputs_save = outputs.clone().squeeze(1) * std + mean + + outputs_save = outputs_save.cpu().numpy() + # outputs_save = np.clip(outputs_save, a_min=-1, a_max=1) + target_save = target.cpu().numpy() + in_save = inputs.cpu().numpy() + + _min, _max = target_save.min(), target_save.max() + target_save = (((target_save - _min) / (_max - _min)) * 255).astype(np.uint8) + in_save = (((in_save - _min) / (_max - _min)) * 255).astype(np.uint8) + outputs_save = (((outputs_save - _min) / (_max - _min)) * 255).astype(np.uint8) + + # Not sure if it was correct to convert to ubyte + outputs_save = img_as_ubyte(outputs_save) + target_save = img_as_ubyte(target_save) + in_save = img_as_ubyte(in_save) + + print("outputs_save shape: ", outputs_save.shape, outputs_save.min(), outputs_save.max()) + print("target_save shape: ", target_save.shape, target_save.min(), target_save.max()) + print("in_save shape: ", in_save.shape, in_save.min(), in_save.max()) + + if len(outputs_save.shape) > 3: + outputs_save = outputs_save.squeeze(0) + target_save = target_save.squeeze(0) + in_save = in_save.squeeze(0) + + if len(outputs_save.shape) > 3: + outputs_save = outputs_save.squeeze(0) + target_save = target_save.squeeze(0) + in_save = in_save.squeeze(0) + + name = name[0].numpy() + name_int = int(name) + + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '.png', target_save) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_in.png', in_save) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_out.png', outputs_save) + + outputs = outputs.squeeze(1) * std + mean + target = 
t2_img.squeeze(1) * std + mean + inputs = t2_in_origin.squeeze(1) * std + mean + + if name_int not in output_dic.keys(): + output_dic[name_int] = [] + target_dic[name_int] = [] + input_dic[name_int] = [] + + output_dic[name_int].append(outputs[0]) + target_dic[name_int].append(target[0]) + input_dic[name_int].append(inputs[0]) + + # print("target/outputs shape: ", target.shape, outputs.shape) + our_nmse = nmse(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_psnr = psnr(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_ssim = ssim(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + + print(' name:{}, slice:{}, nmse:{}, psnr:{}, ssim:{}'.format(name, slice_num[0], our_nmse, our_psnr, our_ssim)) + + nmse_meter_all.append(our_nmse) + psnr_meter_all.append(our_psnr) + ssim_meter_all.append(our_ssim) + # print("psnr_meter_all: ", np.mean(psnr_meter_all)) + + for name in output_dic.keys(): + print("name: ", name, len(output_dic[name])) + # f_output = torch.stack([v for _, v in output_dic[name].items()]) + # f_target = torch.stack([v for _, v in target_dic[name].items()]) + f_output = torch.stack(list(output_dic[name])) + f_target = torch.stack(list(target_dic[name])) + + print("f_output shape: ", f_output.shape) + + if len(f_output.shape) > 3: + f_output = f_output.squeeze(1) + f_target = f_target.squeeze(1) + + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.append(our_nmse) + psnr_meter.append(our_psnr) + ssim_meter.append(our_ssim) + + nmse_meter_score = np.array(nmse_meter) + psnr_meter_score = np.array(psnr_meter) + ssim_meter_score = np.array(ssim_meter) + + nmse_meter_all_score = np.array(nmse_meter_all) + psnr_meter_all_score = np.array(psnr_meter_all) + ssim_meter_all_score = np.array(ssim_meter_all) + + print("===> Evaluate Metric <===") + print("Results") + print("-" * 36) + print(f"{test_sample} NMSE: {np.mean(nmse_meter_score) * 100:.4f} ± {np.std(nmse_meter_score) * 100:.4f}") + print(f"{test_sample} PSNR: {np.mean(psnr_meter_score):.4f} ± {np.std(psnr_meter_score):.4f}") + print(f"{test_sample} SSIM: {np.mean(ssim_meter_score):.4f} ± {np.std(ssim_meter_score):.4f}") + print("-" * 36) + print(f"All NMSE: {np.mean(nmse_meter_all_score) * 100:.4f} ± {np.std(nmse_meter_all_score) * 100:.4f}") + print(f"All PSNR: {np.mean(psnr_meter_all_score):.4f} ± {np.std(psnr_meter_all_score):.4f}") + print(f"All SSIM: {np.mean(ssim_meter_all_score):.4f} ± {np.std(ssim_meter_all_score):.4f}") + print("-" * 36) + print(f"Save Path: {save_path}") + + model.train() + return {'NMSE': np.mean(nmse_meter_score), 'PSNR': np.mean(psnr_meter_score), 'SSIM': np.mean(ssim_meter_score)} + + +from dataloaders.new_m4raw_std_dataloader import M4Raw_TestSet as M4Raw_TestSet_new, M4Raw_TrainSet as M4Raw_TrainSet_new + +from dataloaders.m4raw_dataloader import M4Raw_TestSet, M4Raw_TrainSet + + + + +if __name__ == "__main__": + ## make logger file + if use_kspace: + if use_new_dataloader: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_new_kspace/' + else: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}/' + + + if not isinstance(args.test_tag, type(None)): + snapshot_path = snapshot_path.rstrip("/") + f'_{args.test_tag}/' + + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not frequency_distortion: + snapshot_path = snapshot_path.rstrip("/") + 'no_distortion/' + + # if not 
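
# ----------------------------------------------------------------------------
# Illustrative sketch (editor's addition): checkpoints written from an
# nn.DataParallel model prefix every key with "module.", which is why the
# loading code below strips the prefix before load_state_dict. A compact
# equivalent of that loop:
import torch

def strip_module_prefix(state_dict: dict) -> dict:
    """Remove a leading 'module.' from every key; a no-op for plain checkpoints."""
    return {k[len("module."):] if k.startswith("module.") else k: v
            for k, v in state_dict.items()}

# usage (checkpoint layout as in this repo):
#   ckpt = torch.load("best_checkpoint.pth", map_location="cpu")
#   network.load_state_dict(strip_module_prefix(ckpt["network"]))
# ----------------------------------------------------------------------------
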
os.path.exists(snapshot_path): + # os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + h5_path = "/media/cbtil3/74ec35fd-2452-4dcc-8d7d-3ba957e302c9/m4raw_h5" + os.makedirs(h5_path, exist_ok=True) + debug_predix = "debug_" if DEBUG else "" + + if use_new_dataloader: + db_test = M4Raw_TestSet_new(args, use_kspace=use_kspace, h5_path=os.path.join(h5_path, debug_predix+"test.h5")) # + else: + db_test = M4Raw_TestSet(args.root_path, args.MRIDOWN, use_kspace=use_kspace) + + # db_test = build_dataset(args, mode='val', use_kspace=use_kspace) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'test': + + save_mode_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + print('load weights from ' + save_mode_path) + + try: + checkpoint = torch.load(save_mode_path) + except: + print("Missing keys:", set(model_state_dict.keys()) - set(loaded_state_dict.keys())) + + + weights_dict = {} + for k, v in checkpoint['network'].items(): + new_k = k.replace('module.', '') if 'module' in k else k + weights_dict[new_k] = v + + network.load_state_dict(weights_dict) + network.eval() + + eval_result = evaluate(network, testloader, device, save_path=snapshot_path + '/result_case/') + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/train_brats.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/train_brats.py new file mode 100644 index 0000000000000000000000000000000000000000..8bf7122093072cfbcd488a54c2fef8e1443eff3f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/train_brats.py @@ -0,0 +1,476 @@ +from tqdm import tqdm +from tensorboardX import SummaryWriter +import logging, time, os, sys +import torch.optim as optim +from torchvision import transforms +from torch.utils.data import DataLoader +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import RandomPadCrop, ToTensor +from networks.mynet import TwoBranch +from utils.option import args +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +import matplotlib.pyplot as plt + +from utils.utils import * +from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial +from networks_time.mynet import DiffTwoBranch + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr + +# --use_time_model True --use_kspace True --ACCELERATIONS 4 --MRIDOWN 4X --low_field_SNR 20 --input_normalize mean_std +DEBUG = args.DEBUG +use_time_model = args.use_time_model +use_kspace = args.use_kspace + +# kspace_refine = True # Albu with 41.33, w/ 42.00 +# mask_vacant = False +frequency_distortion = True + +saveroot = "image_results/brats" + + +num_timesteps = args.num_timesteps #30 +image_size = args.image_size #240 +distortion_sigma = 10/255 + + +if args.MRIDOWN == "4X": + 
accelerate_mask = np.load("./dataloaders/example_mask/brats_4X_mask.npy") + accelerate_mask = torch.from_numpy(accelerate_mask).unsqueeze(0).clone().float() +else: + accelerate_mask = None + + + +saveroot = saveroot + "_" + args.MRIDOWN + +# Output a list of k-space kernels +kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0], + accelerate_mask=accelerate_mask + ) + +np.save(f"./dataloaders/example_mask/brats_{args.ACCELERATIONS[0]}_kspace_mask.npy", kspace_masks) + + +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +print("kspace_masks = ", kspace_masks.shape) # kspace_masks = torch.Size([5, 1, 240, 240]) + +os.makedirs(os.path.dirname(saveroot), exist_ok=True) + +def save_mask(saveroot, kspace_masks): + masks_np = kspace_masks.squeeze(1).cpu().numpy() + + # Create a thin horizontal border (e.g., 5 pixels) between masks + # Save some mask + border_thickness = 5 + border = np.zeros((border_thickness, 240)) # white border (or black if you prefer zeros) + + # Stack with borders in-between + composite = [] + for i, mask in enumerate(masks_np): + composite.append(mask) + if i < len(masks_np) - 1: + composite.append(border) # add border between images + stacked_image = np.vstack(composite) + + os.makedirs(saveroot, exist_ok=True) + plt.imsave(saveroot + "/000_kmask.png", stacked_image, cmap='gray') + +save_mask(saveroot, kspace_masks) + + +if __name__ == "__main__": + ## make logger file + if use_kspace: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_kspace/' + + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_train = MyDataset(split='train', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([RandomPadCrop(), ToTensor()]), + base_dir=train_data_path, input_normalize = args.input_normalize, use_kspace=use_kspace) + + db_test = MyDataset(split='test', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([ToTensor()]), + base_dir=test_data_path, input_normalize = args.input_normalize) + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + # fixtrainloader = DataLoader(db_train, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'train': + network.train() + + params = list(network.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + if not use_kspace: + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + else: + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=40000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + 
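
# ----------------------------------------------------------------------------
# Illustrative sketch (editor's addition): in the training loop below, the
# fully sampled target supplies the acquired (masked) frequencies while the
# undersampled input fills the rest, so the network always sees ground-truth
# low-frequency lines plus degraded high frequencies. Written with explicit
# FFTs (the repo wraps the same idea in apply_tofre / apply_to_spatial):
import torch

def mix_kspace(target_img: torch.Tensor, input_img: torch.Tensor,
               mask: torch.Tensor) -> torch.Tensor:
    """Return an image whose sampled k-space lines come from `target_img`."""
    to_k = lambda x: torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
    mixed_k = to_k(target_img) * mask + to_k(input_img) * (1 - mask)
    return torch.fft.ifft2(torch.fft.ifftshift(mixed_k, dim=(-2, -1))).real
# ----------------------------------------------------------------------------
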
max_epoch = max_iterations // len(trainloader) + 1 + if use_kspace: + max_epoch = max_epoch * num_timesteps + + best_status = {'T1_NMSE': 10000000, 'T1_PSNR': 0, 'T1_SSIM': 0, + 'T2_NMSE': 10000000, 'T2_PSNR': 0, 'T2_SSIM': 0} + + fft_weight = 0.01 + criterion = nn.L1Loss().to(device, non_blocking=True) + freloss = Frequency_Loss().to(device, non_blocking=True) + # lpips_loss = LPIPS().eval().to(device) + start_time = time.time() + mask = None + + progress_bar = tqdm(range(max_epoch), ncols=100) + + for epoch_num in progress_bar: + time1 = time.time() + debug_time = False + network.train() + + for i_batch, (sampled_batch, sample_stats) in enumerate(trainloader): + time2 = time.time() + + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + + # Degradation + if use_kspace: + b = t1_in.shape[0] + t = torch.randint(0, num_timesteps, (b,), device=device).long() + mask = kspace_masks[t] + + target_fft, _ = apply_tofre(t2.clone(), mask) + fft, mask = apply_tofre(t2_in.clone(), mask) + + # if np.random.rand() > (1 / (1 + num_timesteps)): + fft = target_fft * mask + fft * (1 - mask) # Seems too easy + + # Frequency Noise + if frequency_distortion: + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + # Add noise to unmasked frequencies to maintain the stochasticity that diffusion models typically rely on. + sigma = distortion_sigma * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude = noise * fft_magnitude * mask # + noise * (1 - mask) + fft_magnitude += noise_magnitude + + # sigma = distortion_sigma / 2 * torch.abs(torch.randn(1)).item() + # noise = torch.randn_like(fft_phase) * sigma + # noise_pha = noise * fft_phase * mask # + noise * (1 - mask) + # fft_phase += noise_pha + + fft = fft_magnitude * torch.exp(1j * fft_phase) + + t2_in = apply_to_spatial(fft) + + time3 = time.time() + + if use_time_model and use_kspace: + outputs = network(t2_in, t1_in, t) + else: + outputs = network(t2_in, t1_in) + + spatial_loss = criterion(outputs['img_out'], t2) + criterion(outputs['img_fre'], t2) + fre_loss = fft_weight * freloss(outputs['img_fre'], t2, mask) # + fft_weight * freloss(outputs['img_out'], t2, mask) + loss = spatial_loss + fre_loss #+ lpips_loss(outputs['img_out'], t2) * 0.01 + + + time4 = time.time() + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + torch.nn.utils.clip_grad_norm_(network.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + if debug_time: + print("Optimizer Step Time: ", time.time() - time2) + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + + progress_bar.set_description( + f"Iter {iter_num} | lr: {scheduler1.get_last_lr()[0]:.2e} | s_loss: {spatial_loss.item():.4f} | fre_loss: {fre_loss.item():.4f}" + ) + + if iter_num % 100 == 0: + # logging.info('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % (iter_num, time.time()-start_time, scheduler1.get_lr()[0], loss.item())) + if DEBUG: + break + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': network.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + + ## ================ Evaluate ================ + 
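
# ----------------------------------------------------------------------------
# Illustrative sketch (editor's addition): the evaluation below collects every
# intermediate reconstruction in `middle_results` and saves them as one
# horizontal strip. A minimal stand-alone version of that visualisation:
import torch
import matplotlib.pyplot as plt

def save_progression(frames, path):
    """frames: list of [1, 1, H, W] tensors; writes them side by side as a PNG."""
    frames = [(f - f.min()) / (f.max() - f.min() + 1e-8) for f in frames]
    strip = torch.cat(frames, dim=-1).squeeze().cpu().numpy()
    plt.imsave(path, strip, cmap='gray')
# ----------------------------------------------------------------------------
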
logging.info(f'Epoch {epoch_num} Evaluation:') + # print() + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + t2_MSE_first_step, t2_PSNR_first_step, t2_SSIM_first_step = [], [], [] + + t1_MSE_krecon, t1_PSNR_krecon, t1_SSIM_krecon = [], [], [] + t2_MSE_krecon, t2_PSNR_krecon, t2_SSIM_krecon = [], [], [] + ids = 0 + + middle_results = [] + test_id_max = 3 + network.eval() + + for (sampled_batch, sample_stats) in testloader: + test_id_max -= 1 + if test_id_max < 0: + break + print("------ Test ID ------:", 3 - test_id_max) + + with torch.no_grad(): + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + # t_merge = torch.cat([t1_in, t2_in], dim=1) + + + if use_kspace: + t = num_timesteps - 1 + t = torch.tensor([t], device=device).long() + + mask = kspace_masks[t] + target_fft, _ = apply_tofre(t2.clone(), mask) + fft, mask = apply_tofre(t2_in.clone(), mask) + + fft = target_fft * mask + fft * (1 - mask) # Seems too easy + t2_in = apply_to_spatial(fft) + origin_t2_in = t2_in.clone() + + + while t >= 0: + if use_time_model: + outputs = network(t2_in, t1_in, t)['img_out'] # t1 + else: + outputs = network(t2_in, t1_in)['img_out'] + + if t == num_timesteps - 1: + first_step_recon = outputs + + + if t == 0: + mask = kspace_masks[0] # last one + t2_in = outputs + + else: + k_full = kspace_masks[-1] # True + t2_in_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # current one + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + + t2_in_fre = t2_in_fre * (1 - k_residual) + recon_sample_fre * k_residual # substitute + + outputs = apply_to_spatial(t2_in_fre) + t2_in = outputs + + + if args.input_normalize == "mean_std": + t2_in_out = t2_in.clone().detach().cpu() + t2_in_out = (t2_in_out - sample_stats['t2_mean']) / sample_stats['t2_std'] + middle_results.append( (t2_in_out - t2_in_out.min())/(t2_in_out.max() - t2_in_out.min()) ) # Normalize to [0, 1] + + else: + middle_results.append( (t2_in - t2_in.min())/(t2_in.max() - t2_in.min()) ) # Normalize to [0, 1] + + + t = t - 1 + + t2_out = t2_in + else: + t2_out = network(t2_in, t1)['img_out'] + + t1_out = None + + if args.input_normalize == "mean_std": + t1_mean = sample_stats['t1_mean'].data.cpu().numpy()[0] + t1_std = sample_stats['t1_std'].data.cpu().numpy()[0] + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + + origin_t2_in = (np.clip(origin_t2_in.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + 
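
# ----------------------------------------------------------------------------
# Illustrative sketch (editor's addition): the uint8 conversions around here
# first undo the mean/std standardisation, then clip to [0, 1] and rescale to
# 0-255 so the metrics and the saved PNGs operate in the same display range:
import numpy as np

def to_uint8(x: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Map a standardised slice back to intensity units, then to 0-255."""
    return (np.clip(x * std + mean, 0.0, 1.0) * 255).astype(np.uint8)
# ----------------------------------------------------------------------------
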
t2_first_step_recon_img = (np.clip(first_step_recon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + + else: + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + + origin_t2_in = (np.clip(origin_t2_in.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_first_step_recon_img = (np.clip(first_step_recon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + + if t1_out is not None: + + MSE = mean_squared_error(t1_img, t1_out_img) + PSNR = peak_signal_noise_ratio(t1_img, t1_out_img) + SSIM = structural_similarity(t1_img, t1_out_img) + t1_MSE_all.append(MSE) + t1_PSNR_all.append(PSNR) + t1_SSIM_all.append(SSIM) + + MSE = mean_squared_error(t1_img, t1_krecon_img) + PSNR = peak_signal_noise_ratio(t1_img, t1_krecon_img) + SSIM = structural_similarity(t1_img, t1_krecon_img) + t1_MSE_krecon.append(MSE) + t1_PSNR_krecon.append(PSNR) + t1_SSIM_krecon.append(SSIM) + + + if t2_out is not None: + MSE = mean_squared_error(t2_img, t2_out_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_out_img) + SSIM = structural_similarity(t2_img, t2_out_img) + t2_MSE_all.append(MSE) + t2_PSNR_all.append(PSNR) + t2_SSIM_all.append(SSIM) + # print("[t2 MRI] MSE:", MSE, "PSNR:", PSNR, "SSIM:", SSIM) + + MSE = mean_squared_error(t2_img, t2_first_step_recon_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_first_step_recon_img) + SSIM = structural_similarity(t2_img, t2_first_step_recon_img) + t2_MSE_first_step.append(MSE) + t2_PSNR_first_step.append(PSNR) + t2_SSIM_first_step.append(SSIM) + + MSE = mean_squared_error(t2_img, t2_krecon_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_krecon_img) + SSIM = structural_similarity(t2_img, t2_krecon_img) + t2_MSE_krecon.append(MSE) + t2_PSNR_krecon.append(PSNR) + t2_SSIM_krecon.append(SSIM) + + ids += 1 + if ids > 360: + break + + if t1_out is not None: + t1_mse = np.array(t1_MSE_all).mean() + t1_psnr = np.array(t1_PSNR_all).mean() + t1_ssim = np.array(t1_SSIM_all).mean() + + t1_krecon_mse = np.array(t1_MSE_krecon).mean() + t1_krecon_psnr = np.array(t1_PSNR_krecon).mean() + t1_krecon_ssim = np.array(t1_SSIM_krecon).mean() + + t2_mse = np.array(t2_MSE_all).mean() + t2_psnr = np.array(t2_PSNR_all).mean() + t2_ssim = np.array(t2_SSIM_all).mean() + + t2_first_step_mse = np.array(t2_MSE_first_step).mean() + t2_first_step_psnr = np.array(t2_PSNR_first_step).mean() + t2_first_step_ssim = np.array(t2_SSIM_first_step).mean() + + t2_krecon_mse = np.array(t2_MSE_krecon).mean() + t2_krecon_psnr = np.array(t2_PSNR_krecon).mean() + t2_krecon_ssim = np.array(t2_SSIM_krecon).mean() + + + # Test Visualization + test_id_max = 0 + os.makedirs(saveroot, exist_ok=True) + middle_results = torch.concat(middle_results, dim=-1).squeeze() # Concatenate along the batch dimension + middle_results = middle_results.detach().cpu().numpy() + plt.imsave(os.path.join(saveroot, f"{iter_num}-{test_id_max}-middle.png"), middle_results, cmap='gray') + + + + results = np.concatenate((origin_t2_in, t2_img, t2_out_img, t2_first_step_recon_img), axis=1) + plt.imsave(os.path.join(saveroot, 
f"{iter_num}-{test_id_max}-compare.png"), results, cmap='gray') + + + + if t2_psnr > best_status['T2_PSNR']: + best_status = {'T2_NMSE': t2_mse, 'T2_PSNR': t2_psnr, 'T2_SSIM': t2_ssim} + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': network.state_dict()}, best_checkpoint_path) + print('New Best Network:') + + logging.info(f"[T2 First MRI:] average MSE: {t2_first_step_mse} average PSNR: {t2_first_step_psnr} average SSIM: {t2_first_step_ssim}") + logging.info(f"[T2 MRI:] average MSE: {t2_mse} average PSNR: {t2_psnr} average SSIM: {t2_ssim}") + print("Snapshot_path = ", snapshot_path) + + if iter_num > max_iterations: + break + + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': network.state_dict()}, + save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + writer.close() \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/train_fastmri.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/train_fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..c764a19fccfce0b55dd6847d4d703d0f6798587e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/train_fastmri.py @@ -0,0 +1,452 @@ +import os +import sys +from tqdm import tqdm +from tensorboardX import SummaryWriter + +import logging +import time +import torch.optim as optim +from torch.utils.data import DataLoader +from networks.mynet import TwoBranch +from networks_time.mynet import DiffTwoBranch + +from utils.option import args + +from dataloaders.fastmri import build_dataset +from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial +from utils.lpips import LPIPS +from utils.metric import nmse, psnr, ssim, AverageMeter +from collections import defaultdict +import matplotlib.pyplot as plt + + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr +from utils.utils import * + +# --num_timesteps 30 --image_size 320 --use_kspace True --use_time_model True --ACCELERATIONS 4 --gpu 0 --phase train + +DEBUG = args.DEBUG +use_kspace = args.use_kspace +frequency_distortion = True + + +use_time_model = args.use_time_model +num_timesteps = args.num_timesteps + +image_size = 320 +distortion_sigma = 10 / 255 + + +saveroot = "image_results/fastmri" + "_" + args.MRIDOWN +print("saveroot=", saveroot) + +os.makedirs(os.path.dirname(saveroot), exist_ok=True) + + +if args.phase == 'test': + kspace_masks = np.load(f"./dataloaders/example_mask/kspace_{args.ACCELERATIONS[0]}_mask.npy") + kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +else: + # Output a list of k-space kernels + kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0], + ) # args.ACCELERATIONS = [4] or [8] + + np.save(f"./dataloaders/example_mask/kspace_{args.ACCELERATIONS[0]}_mask.npy", kspace_masks) + + +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + +print("kspace kernels shape:", kspace_masks.shape) # (1, 1, 320, 320) + + +def save_mask(saveroot, kspace_masks): + masks_np = kspace_masks.squeeze(1).cpu().numpy() + + # Create a thin horizontal border (e.g., 5 pixels) between masks + 
# Save some mask + border_thickness = 5 + border = np.zeros((border_thickness, image_size)) # white border (or black if you prefer zeros) + + # Stack with borders in-between + composite = [] + for i, mask in enumerate(masks_np): + composite.append(mask) + if i < len(masks_np) - 1: + composite.append(border) # add border between images + stacked_image = np.vstack(composite) + + os.makedirs(saveroot, exist_ok=True) + plt.imsave(saveroot + "/000_kmask.png", stacked_image, cmap='gray') + +save_mask(saveroot, kspace_masks) + + +@torch.no_grad() +def evaluate(model, data_loader, device): + model.eval() + + nmse_meter, psnr_meter, ssim_meter = AverageMeter(), AverageMeter(), AverageMeter() + direct_nmse, direct_psnr, direct_ssim = AverageMeter(), AverageMeter(), AverageMeter() + output_dic = defaultdict(dict) + target_dic = defaultdict(dict) + input_dic = defaultdict(dict) + direct_dic = defaultdict(dict) + + for id, data in enumerate(data_loader): + pd, pdfs, _ = data + name = os.path.basename(pdfs[4][0]).split('.')[0] + + target = pdfs[1].to(device) + mean, std = pdfs[2], pdfs[3] + + fname = pdfs[4] + slice_num = pdfs[5] + + mean = mean.unsqueeze(1).unsqueeze(2).to(device) + std = std.unsqueeze(1).unsqueeze(2).to(device) + + pd_img = pd[1].unsqueeze(1).to(device) + + pdfs_img = pdfs[0].unsqueeze(1).to(device) + + pdfs_img_origin = pdfs_img.clone() + + # Degradation + if use_kspace: + b = target.size(0) + t = num_timesteps - 1 + t = torch.tensor([t], device=device).long() # t-1 + + # t = torch.randint(num_timesteps - 1, num_timesteps, (b,), device=device).long() # t-1 + + mask = kspace_masks[t] + fft, mask = apply_tofre(target.clone(), mask) + fft = fft * mask + 0.0 + pdfs_img = apply_to_spatial(fft) + + pdfs_img_origin = pdfs_img.clone() + + middle_results = [] + while t >= 0: + with torch.no_grad(): + if use_time_model: + outputs = model(pdfs_img, pd_img, t)['img_out'] + else: + outputs = model(pdfs_img, pd_img)['img_out'] + + if t == num_timesteps - 1: + direct_recon = outputs + + if t == 0: + mask = kspace_masks[0] # last one + pdfs_img = outputs + + else: + k_full = kspace_masks[-1] + faded_recon_sample_fre, k_full = apply_tofre(pdfs_img, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t - 1] # get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] # self.get_kspace_kernels(t - 1).cuda() # last one + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + + faded_recon_sample_fre = faded_recon_sample_fre * (1 - k_residual) + recon_sample_fre * k_residual + + outputs = apply_to_spatial(faded_recon_sample_fre) + pdfs_img = outputs + + if args.input_normalize == "mean_std": + pdfs_img_out = pdfs_img.clone().detach().cpu() + pdfs_img_out = pdfs_img * std + mean + middle_results.append( (pdfs_img_out - pdfs_img_out.min())/(pdfs_img_out.max() - pdfs_img_out.min()) ) # Normalize to [0, 1] + + else: + middle_results.append( (pdfs_img - pdfs_img.min())/(pdfs_img.max() - pdfs_img.min()) ) # Normalize to [0, 1] + + t = t - 1 + + else: + outputs = model(pdfs_img, pd_img)['img_out'] + + # print("-----\nori target = ", target.max(), target.min()) + # print("ori inputs = ", inputs.max(), inputs.min()) + # print("ori outputs = ", outputs.max(), outputs.min()) + # print("ori direct_recon = ", direct_recon.max(), direct_recon.min()) + + + target = target * std + mean + inputs = pdfs_img.squeeze(1) * std + mean + outputs = outputs.squeeze(1) * std + mean + direct_recon = direct_recon.squeeze(1) * std + mean + + # print("-----\ntarget = ", target.max(), 
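
# ----------------------------------------------------------------------------
# Illustrative sketch (editor's addition): a few lines further down, slice-level
# outputs are grouped per file name and stacked into volumes before NMSE, PSNR
# and SSIM are computed. A condensed version of that bookkeeping:
from collections import defaultdict
import torch

def volume_metrics(output_dic, target_dic, metric_fns):
    """output_dic / target_dic: {fname: {slice_idx: 2-D tensor}} -> per-volume scores."""
    scores = defaultdict(list)
    for fname in output_dic:
        pred = torch.stack([output_dic[fname][s] for s in sorted(output_dic[fname])])
        gt = torch.stack([target_dic[fname][s] for s in sorted(target_dic[fname])])
        for metric_name, fn in metric_fns.items():
            scores[metric_name].append(fn(gt.cpu().numpy(), pred.cpu().numpy()))
    return scores
# ----------------------------------------------------------------------------
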
target.min()) + # print("inputs = ", inputs.max(), inputs.min()) + # print("outputs = ", outputs.max(), outputs.min()) + # print("direct_recon = ", direct_recon.max(), direct_recon.min()) + + min, max = target.min(), target.max() + target_mm = (target - min) / (max - min) + inputs_mm = (inputs - min) / (max - min) + outputs_mm = (outputs - min) / (max - min) + direct_recon_mm = (direct_recon - min) / (max - min) + + inputs_mm = torch.clamp(inputs_mm, 0, 1) + outputs_mm = torch.clamp(outputs_mm, 0, 1) + direct_recon_mm = torch.clamp(direct_recon_mm, 0, 1) + + if id < 3: + os.makedirs(saveroot, exist_ok=True) + middle_results = torch.concat(middle_results, dim=-1).squeeze() # Concatenate along the batch dimension + middle_results = middle_results.detach().cpu().numpy() + plt.imsave(os.path.join(saveroot, f"{iter_num}-{id}-middle.png"), middle_results, cmap='gray') + + + results = torch.concat((inputs_mm, target_mm, outputs_mm, direct_recon_mm), dim=2).detach().cpu().numpy()[0] + plt.imsave(os.path.join(saveroot, f"{iter_num}-{id}-compare.png"), results, cmap='gray') + + + for i, f in enumerate(fname): + output_dic[f][slice_num[i]] = outputs[i] + target_dic[f][slice_num[i]] = target[i] + input_dic[f][slice_num[i]] = inputs[i] + direct_dic[f][slice_num[i]] = direct_recon[i] + + if id > 100: + break + + for name in output_dic.keys(): + f_output = torch.stack([v for _, v in output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.update(our_nmse, 1) + psnr_meter.update(our_psnr, 1) + ssim_meter.update(our_ssim, 1) + + direct_nmse.update( + nmse(f_target.cpu().numpy(), torch.stack([v for _, v in direct_dic[name].items()]).cpu().numpy()), 1) + direct_psnr.update( + psnr(f_target.cpu().numpy(), torch.stack([v for _, v in direct_dic[name].items()]).cpu().numpy()), 1) + direct_ssim.update( + ssim(f_target.cpu().numpy(), torch.stack([v for _, v in direct_dic[name].items()]).cpu().numpy()), 1) + + print("==> Evaluate Metric") + print("Direct Results ----------") + print("NMSE: {:.4}".format(direct_nmse.avg)) + print("PSNR: {:.4}".format(direct_psnr.avg)) + print("SSIM: {:.4}".format(direct_ssim.avg)) + print("------------------") + + print("==> Evaluate Metric") + print("Results ----------") + print("NMSE: {:.4}".format(nmse_meter.avg)) + print("PSNR: {:.4}".format(psnr_meter.avg)) + print("SSIM: {:.4}".format(ssim_meter.avg)) + print("------------------") + model.train() + + return {'NMSE': nmse_meter.avg, 'PSNR': psnr_meter.avg, 'SSIM': ssim_meter.avg} + + +if __name__ == "__main__": + ## make logger file + if use_kspace: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_kspace/' + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + if use_time_model: + network = DiffTwoBranch(args).cuda() + else: + network = TwoBranch(args).cuda() + # network = build_model_from_name(args).cuda() + device = torch.device('cuda') + network.to(device) + lpips_loss = LPIPS().eval().to(device) + + if len(args.gpu.split(',')) > 
1: + network = nn.DataParallel(network) + # network = nn.SyncBatchNorm.convert_sync_batchnorm(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_train = build_dataset(args, mode='train', use_kspace=use_kspace) + db_test = build_dataset(args, mode='val', use_kspace=use_kspace) + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + mask = None + + if args.phase == 'train': + network.train() + + params = list(network.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + max_epoch = max_iterations // len(trainloader) + 1 + + best_status = {'NMSE': 10000000, 'PSNR': 0, 'SSIM': 0} + + fft_weight = 0.01 + criterion = nn.L1Loss().to(device, non_blocking=True) + freloss = Frequency_Loss().to(device, non_blocking=True) + + progress_bar = tqdm(range(max_epoch), ncols=100) + + for epoch_num in progress_bar: + time1 = time.time() + start_time = time.time() + network.train() + + for i_batch, sampled_batch in enumerate(trainloader): + time2 = time.time() + + pd, pdfs, _ = sampled_batch + target = pdfs[1] + + mean, std = pdfs[2], pdfs[3] + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + target = target.unsqueeze(1) + + b = pd_img.size(0) + + pd_img = pd_img.to(device) # [4, 1, 320, 320] + pdfs_img = pdfs_img.to(device) # [4, 1, 320, 320] + target = target.to(device) # [4, 1, 320, 320] + + time3 = time.time() + + # Degradation + if use_kspace: + t = torch.randint(0, num_timesteps, (b,), device=device).long() + mask = kspace_masks[t] + + fft, mask = apply_tofre(target.clone(), mask) + fft = fft * mask + + # Frequency Noise + if frequency_distortion: + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + # Add noise to unmasked frequencies to maintain the stochasticity that diffusion models typically rely on. + sigma = distortion_sigma * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude = noise * fft_magnitude * mask # + noise * (1 - mask) + fft_magnitude += noise_magnitude + + # sigma = distortion_sigma / 2 * torch.abs(torch.randn(1)).item() + # noise = torch.randn_like(fft_phase) * sigma + # noise_pha = noise * fft_phase * mask # + noise * (1 - mask) + # fft_phase += noise_pha + + fft = fft_magnitude * torch.exp(1j * fft_phase) + + pdfs_img = apply_to_spatial(fft) + + # breakpoint() + if use_time_model: + outputs = network(pdfs_img, pd_img, t) + else: + outputs = network(pdfs_img, pd_img) + + + spatial_loss = criterion(outputs['img_out'], target) + criterion(outputs['img_fre'], target) + fre_loss = fft_weight * freloss(outputs['img_fre'], target, mask) #+ fft_weight * freloss(outputs['img_out'], target, mask) + loss = spatial_loss + fre_loss + fre_loss # + 0.01 * lpips_loss(outputs['img_out'], target).mean() + + # 0.01 * lpips_loss(outputs['img_out'], target).mean() # + + time4 = time.time() + + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + ### clip the gradients to a small range. 
+ torch.nn.utils.clip_grad_norm_(network.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + + progress_bar.set_description( + f"Iter {iter_num} | lr: {scheduler1.get_last_lr()[0]:.2e} | s_loss: {spatial_loss.item():.4f} | fre_loss: {fre_loss.item():.4f}" + ) + + print_iter = 100 # if not DEBUG else 5 + if iter_num % print_iter == 0 and DEBUG: + break + # logging.info('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % ( + # iter_num, time.time() - start_time, scheduler1.get_lr()[0], loss.item())) + # if : + # break + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': network.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + ## ================ Evaluate ================ + logging.info(f'\nEpoch {epoch_num} Evaluation:') + # print() + network.eval() + eval_result = evaluate(network, testloader, device) + + if eval_result['PSNR'] > best_status['PSNR']: + best_status = {'NMSE': eval_result['NMSE'], 'PSNR': eval_result['PSNR'], 'SSIM': eval_result['SSIM']} + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': network.state_dict()}, best_checkpoint_path) + print('New Best Network saved:', best_checkpoint_path) + + logging.info( + f"average MSE: {eval_result['NMSE']} average PSNR: {eval_result['PSNR']} average SSIM: {eval_result['SSIM']}") + print("Snapshot Path: ", snapshot_path) + + if iter_num > max_iterations: + break + + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': network.state_dict()}, + save_mode_path) + print("save model to {}".format(save_mode_path)) + writer.close() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/train_m4raw.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/train_m4raw.py new file mode 100644 index 0000000000000000000000000000000000000000..5148a933d3268bd4b536a6462c1169a02f4689fe --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/train_m4raw.py @@ -0,0 +1,521 @@ +import os +import sys +from tqdm import tqdm +from tensorboardX import SummaryWriter +import logging +import time +import torch.optim as optim +from torch.utils.data import DataLoader +from networks.mynet import TwoBranch +from networks_time.mynet import DiffTwoBranch + +from utils.option import args +import matplotlib.pyplot as plt + +use_new_dataloader = True + +if use_new_dataloader: + from dataloaders.new_m4raw_std_dataloader import M4Raw_TestSet, M4Raw_TrainSet, normalize, normalize_instance_dim +else: + from dataloaders.m4raw_dataloader import M4Raw_TestSet, M4Raw_TrainSet, normalize, normalize_instance_dim + +from frequency_diffusion.degradation.k_degradation import get_ksu_kernel, apply_tofre, apply_to_spatial +from utils.lpips import LPIPS +from utils.metric import nmse, psnr, ssim, AverageMeter +from collections import defaultdict +# import imsave +from skimage.io import imsave + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr +from utils.utils import * + + + + + +num_timesteps = args.num_timesteps +image_size = args.image_size 
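
# ----------------------------------------------------------------------------
# Illustrative sketch (editor's addition): distortion_sigma below controls the
# magnitude-only noise injected into the sampled k-space coefficients during
# training (the phase is left untouched). A self-contained version:
import torch

def add_magnitude_noise(fft: torch.Tensor, mask: torch.Tensor,
                        base_sigma: float) -> torch.Tensor:
    """Perturb the magnitude of the masked coefficients, keep the phase."""
    magnitude, phase = torch.abs(fft), torch.angle(fft)
    sigma = base_sigma * torch.abs(torch.randn(1)).item()   # random noise level
    magnitude = magnitude + torch.randn_like(magnitude) * sigma * magnitude * mask
    return magnitude * torch.exp(1j * phase)
# ----------------------------------------------------------------------------
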
+distortion_sigma = 10 / 255 # 5/255 + + + +DEBUG = args.DEBUG +use_kspace = args.use_kspace +use_time_model = args.use_time_model + + + +num_timesteps = args.num_timesteps +image_size = args.image_size #32 + + +frequency_distortion = True + + +# Output a list of k-space kernels +kspace_masks = get_ksu_kernel(num_timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=args.ACCELERATIONS[0]) + + +np.save(f"./dataloaders/example_mask/m4raw_{args.ACCELERATIONS[0]}_mask.npy", kspace_masks) + +kspace_masks = torch.from_numpy(np.asarray(kspace_masks)).cuda() + + +saveroot = "image_results/m4raw" + "_" + args.ACCELERATIONS + "X" + +os.makedirs(os.path.dirname(saveroot), exist_ok=True) + + + +def save_mask(saveroot, kspace_masks): + masks_np = kspace_masks.squeeze(1).cpu().numpy() + + # Create a thin horizontal border (e.g., 5 pixels) between masks + # Save some mask + border_thickness = 5 + border = np.zeros((border_thickness, image_size)) # white border (or black if you prefer zeros) + + # Stack with borders in-between + composite = [] + for i, mask in enumerate(masks_np): + composite.append(mask) + if i < len(masks_np) - 1: + composite.append(border) # add border between images + stacked_image = np.vstack(composite) + + os.makedirs(saveroot, exist_ok=True) + plt.imsave(saveroot + "/000_kmask.png", stacked_image, cmap='gray') + +save_mask(saveroot, kspace_masks) + + + +@torch.no_grad() +def evaluate(iter_num, model, data_loader, device): + model.eval() + print_i = 1 + + nmse_meter = AverageMeter() + psnr_meter = AverageMeter() + ssim_meter = AverageMeter() + first_psnr_meter = AverageMeter() + first_ssim_meter = AverageMeter() + first_nmse_meter = AverageMeter() + + output_dic = defaultdict(dict) + target_dic = defaultdict(dict) + input_dic = defaultdict(dict) + first_dic = defaultdict(dict) + + for id, sampled_batch in enumerate(data_loader): + + if use_new_dataloader: + t1_img, t1_in = sampled_batch['t1'], sampled_batch['t1_in'] + t2_img, t2_in = sampled_batch['t2'], sampled_batch['t2_in'] + else: + t1_img, t1_in = sampled_batch['ref_image_full'], sampled_batch['ref_image_sub'] + t2_img, t2_in = sampled_batch['tag_image_full'], sampled_batch['tag_image_sub'] + + + t1_img = t1_img.to(device) + t2_img = t2_img.to(device) + + mean, std = sampled_batch['t2_mean'], sampled_batch['t2_std'] + + fname = sampled_batch['fname'] + slice_num = sampled_batch['slice'] + + mean = mean.unsqueeze(1) #.to(device) + std = std.unsqueeze(1) #.to(device) + + t2_in_origin = t2_img.clone() + + middle_results = [] + # Degradation + if use_kspace: + t = num_timesteps - 1 + t = torch.tensor([t], device=device).long() # t-1 + + mask = kspace_masks[t] + fft, mask = apply_tofre(t2_img.clone(), mask) # t2_img | t2_in + + fft = fft * mask + 0.0 + + t2_in = apply_to_spatial(fft) + t2_in_origin = t2_in.clone() + + while t >= 0: + with torch.no_grad(): + if use_time_model: + outputs = model(t2_in, t1_img, t)['img_out'] + else: + outputs = model(t2_in, t1_img)['img_out'] + + + if t == num_timesteps - 1: + first_step_recon = outputs + + if t == 0: + mask = kspace_masks[0] # last one + t2_in = outputs + + else: + k_full = kspace_masks[-1] + t2_in_fre, k_full = apply_tofre(t2_in, k_full) + + with torch.no_grad(): + + kt_sub_1 = kspace_masks[t-1] #get_kspace_kernels(t - 2).cuda() + kt = kspace_masks[t] #self.get_kspace_kernels(t - 1).cuda() # last one + k_residual = kt_sub_1 - kt + + recon_sample_fre, k_residual = apply_tofre(outputs, k_residual) + + + t2_in_fre = t2_in_fre * (1 - k_residual) + recon_sample_fre 
* k_residual # substitute + + # faded_recon_sample_fre = faded_recon_sample_fre * kt_sub_1 + outputs = apply_to_spatial(t2_in_fre) + t2_in = outputs + + if args.input_normalize == "mean_std": + t2_in_out = t2_in.clone().detach().cpu() + t2_in_out = t2_in_out * std + mean #(t2_in_out - mean) / std + middle_results.append( (t2_in_out - t2_in_out.min())/(t2_in_out.max() - t2_in_out.min()) ) # Normalize to [0, 1] + + else: + middle_results.append( (t2_in - t2_in.min())/(t2_in.max() - t2_in.min()) ) # Normalize to [0, 1] + + t = t - 1 + + else: + outputs = model(t2_in, t1_img)['img_out'] + + if print_i: + t2_in_save = torch.cat([t2_in, t2_in_origin, t2_img], dim=3).cpu().numpy()[0, 0] + + t2_in_save = (t2_in_save - t2_in_save.min()) / (t2_in_save.max() - t2_in_save.min()) + # t2_in_save = np.stack([t2_in_save, t2_in_save, t2_in_save], axis=2) + # save to file + os.makedirs("./debug", exist_ok=True) + save_path = f"./debug/{use_kspace}_{fname[0]}_{slice_num[0]}.png" + plt.imsave(save_path, t2_in_save, cmap='gray') + print_i = 0 + + + # -1 ~ 6 + # print("-----\nori target = ", t2_img.max(), t2_img.min()) + # print("ori inputs = ", t2_in_origin.max(), t2_in_origin.min()) + # print("ori outputs = ", outputs.max(), outputs.min()) + # print("ori direct_recon = ", first_step_recon.max(), first_step_recon.min()) + + + + t2_img = t2_img.squeeze(1).cpu() * std + mean + inputs = t2_in_origin.squeeze(1).cpu() * std + mean + outputs = outputs.squeeze(1).cpu() * std + mean + first_step_recon = first_step_recon.squeeze(1).cpu() * std + mean + + # -2 ~ 100? + # print("-----\ntarget = ", t2_img.max(), t2_img.min()) + # print("inputs = ", inputs.max(), inputs.min()) + # print("outputs = ", outputs.max(), outputs.min()) + # print("direct_recon = ", first_step_recon.max(), first_step_recon.min()) + + + min, max = t2_img.min(), t2_img.max() + t2_img_mm = (t2_img - min) / (max - min) + inputs_mm = (inputs - min) / (max - min) + outputs_mm = (outputs - min) / (max - min) + first_step_recon_mm = (first_step_recon - min) / (max - min) + + + inputs_mm = torch.clamp(inputs_mm, 0, 1) + outputs_mm = torch.clamp(outputs_mm, 0, 1) + first_step_recon_mm = torch.clamp(first_step_recon_mm, 0, 1) + + + + if id < 5: + os.makedirs(saveroot, exist_ok=True) + middle_results = torch.concat(middle_results, dim=-1).squeeze() # Concatenate along the batch dimension + middle_results = middle_results.detach().cpu().numpy() + plt.imsave(os.path.join(saveroot, f"{iter_num}-{id}-middle.png"), middle_results, cmap='gray') + + + + results = torch.concat((inputs_mm, t2_img_mm, outputs_mm, first_step_recon_mm), dim=2).detach().cpu().numpy()[0] + plt.imsave(os.path.join(saveroot, f"{iter_num}-{id}-compare.png"), results, cmap='gray') + + + # t2_img, outputs, inputs, first_step_recon + + for i, f in enumerate(fname): + + output_dic[f][slice_num[i]] = outputs[i] + target_dic[f][slice_num[i]] = t2_img[i] + input_dic[f][slice_num[i]] = inputs[i] + first_dic[f][slice_num[i]] = first_step_recon[i] + + + if id > 360: + break + + + + for name in output_dic.keys(): + f_output = torch.stack([v for _, v in output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + f_first = torch.stack([v for _, v in first_dic[name].items()]) + + + # Range? 
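
# ----------------------------------------------------------------------------
# Illustrative sketch (editor's addition): the meters updated just below keep
# both a running average (.avg) and the raw per-volume values (.score), so a
# mean ± std summary can be produced like this:
import numpy as np
from utils.metric import AverageMeter   # defined later in this diff

def summarise(meter: AverageMeter, name: str) -> str:
    scores = np.asarray(meter.score)
    return f"{name}: {scores.mean():.4f} ± {scores.std():.4f}"
# ----------------------------------------------------------------------------
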
~100 + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.update(our_nmse, 1) + psnr_meter.update(our_psnr, 1) + ssim_meter.update(our_ssim, 1) + + first_nmse = nmse(f_target.cpu().numpy(), f_first.cpu().numpy()) + first_psnr = psnr(f_target.cpu().numpy(), f_first.cpu().numpy()) + first_ssim = ssim(f_target.cpu().numpy(), f_first.cpu().numpy()) + + first_nmse_meter.update(first_nmse, 1) + first_psnr_meter.update(first_psnr, 1) + first_ssim_meter.update(first_ssim, 1) + + print("==> First Step Metric") + print("Results ----------") + print("NMSE: {:.4}".format(first_nmse_meter.avg)) + print("PSNR: {:.4}".format(first_psnr_meter.avg)) + print("SSIM: {:.4}".format(first_ssim_meter.avg)) + print("------------------") + + print("==> Evaluate Metric") + print("Results ----------") + print("NMSE: {:.4}".format(nmse_meter.avg)) + print("PSNR: {:.4}".format(psnr_meter.avg)) + print("SSIM: {:.4}".format(ssim_meter.avg)) + print("------------------") + model.train() + + return {'NMSE': nmse_meter.avg, 'PSNR': psnr_meter.avg, 'SSIM':ssim_meter.avg} + + + +if __name__ == "__main__": + ## make logger file + if use_kspace: + if use_new_dataloader: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}_new_kspace/' + else: + snapshot_path = snapshot_path.rstrip("/") + f'_t{num_timesteps}/' + + + if not isinstance(args.test_tag, type(None)): + snapshot_path = snapshot_path.rstrip("/") + f'_{args.test_tag}/' + + if use_time_model: + snapshot_path = snapshot_path.rstrip("/") + '_time/' + + if not frequency_distortion: + snapshot_path = snapshot_path.rstrip("/") + 'no_distortion/' + + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + if use_time_model: + model = DiffTwoBranch(args).cuda() + else: + model = TwoBranch(args).cuda() + + # model = build_model_from_name(args).cuda() + device = torch.device('cuda') + model.to(device) + lpips_loss = LPIPS().eval().to(device) + + if len(args.gpu.split(',')) > 1: + model = nn.DataParallel(model) + # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + debug_predix = "debug_" if DEBUG else "" + + + h5_path = "/media/cbtil3/74ec35fd-2452-4dcc-8d7d-3ba957e302c9/m4raw_h5" + os.makedirs(h5_path, exist_ok=True) + + db_train = M4Raw_TrainSet(args, use_kspace=use_kspace, DEBUG=DEBUG, h5_path=os.path.join(h5_path, debug_predix+"train.h5")) # build_dataset(args, mode='train') + db_test = M4Raw_TestSet(args, use_kspace=use_kspace, DEBUG=DEBUG, h5_path=os.path.join(h5_path, debug_predix+"test.h5")) # + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'train': + model.train() + + params = list(model.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + + if not use_kspace: + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + else: + scheduler1 = 
optim.lr_scheduler.StepLR(optimizer1, step_size=40000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + max_epoch = max_iterations // len(trainloader) + 1 + + best_status = {'NMSE': 10000000, 'PSNR': 0, 'SSIM': 0} + + fft_weight=0.01 + criterion = nn.MSELoss().to(device, non_blocking=True) #nn.L1Loss().to(device, non_blocking=True) + freloss = Frequency_Loss().to(device, non_blocking=True) + t = 0 + + for epoch_num in tqdm(range(max_epoch), ncols=70): + time1 = time.time() + start_time = time.time() + model.train() + + progress_bar = tqdm(enumerate(trainloader), ncols=100, desc=f'Epoch {epoch_num + 1}/{max_epoch}', unit='it') + + for i_batch, sampled_batch in progress_bar: + time2 = time.time() + + # T1 is the reference image, T2 is the target image + t1_img, t1_in = sampled_batch['t1'], sampled_batch['t1_in'] + t2_img, t2_in = sampled_batch['t2'], sampled_batch['t2_in'] + + t1_img = t1_img.to(device) + t1_in = t1_in.to(device) + t2_img = t2_img.to(device) + + t2_in = t2_img.float().clone() + + time3 = time.time() + + + # Degradation + if use_kspace: + b = t1_in.size(0) + t = torch.randint(0, num_timesteps, (b,), device=device).long() + mask = kspace_masks[t] + + fft, mask = apply_tofre(t2_in, mask) # TODO t2_img | t2_in + fft = fft * mask + + # Frequency Noise + if frequency_distortion: + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + sigma = distortion_sigma * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude = noise * fft_magnitude * mask # + noise * (1 - mask) + fft_magnitude += noise_magnitude + + + # sigma = distortion_sigma_t / 2 * torch.abs(torch.rand(1)).item() + # noise = torch.randn_like(fft_phase) * sigma + # noise_pha = noise * fft_phase * mask # + noise * (1 - mask) + # fft_phase += noise_pha + + fft = fft_magnitude * torch.exp(1j * fft_phase) + + t2_in = apply_to_spatial(fft) + + # breakpoint() + if use_time_model: + outputs = model(t2_in, t1_img, t) # ['img_out'] + else: + outputs = model(t2_in, t1_img) # ['img_out'] + + + spatial_loss = criterion(outputs['img_out'], t2_img) + criterion(outputs['img_fre'], t2_img) + fre_loss = fft_weight * freloss(outputs['img_fre'], t2_img, mask) + fft_weight * freloss(outputs['img_out'], t2_img, mask) + loss = spatial_loss + fre_loss # + 0.01 * lpips_loss(outputs['img_out'], t2_img).mean() + + + time4 = time.time() + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + ### clip the gradients to a small range. 
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + + + progress_bar.set_description( + f"Iter {iter_num} | lr: {scheduler1.get_last_lr()[0]:.2e} | s_loss: {spatial_loss.item():.4f} | fre_loss: {fre_loss.item():.4f}" + ) + + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + print_iter = 100 #if not DEBUG else 5 + if iter_num % print_iter == 0: + # logging.info('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % (iter_num, time.time() - start_time, scheduler1.get_lr()[0], loss.item())) + if DEBUG: + break + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': model.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + + ## ================ Evaluate ================ + logging.info(f'Epoch {epoch_num} Evaluation:') + # print() + model.eval() + + eval_result = evaluate(iter_num, model, testloader, device) + + if eval_result['PSNR'] > best_status['PSNR']: + best_status = {'NMSE': eval_result['NMSE'], 'PSNR': eval_result['PSNR'], 'SSIM': eval_result['SSIM']} + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': model.state_dict()}, best_checkpoint_path) + print('New Best Network saved:', best_checkpoint_path) + + logging.info(f"average MSE: {eval_result['NMSE']} average PSNR: {eval_result['PSNR']} average SSIM: {eval_result['SSIM']}") + print("snapshot_path=", snapshot_path) + + if iter_num > max_iterations: + break + + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': model.state_dict()}, + save_mode_path) + print("save model to {}".format(save_mode_path)) + writer.close() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/__init__.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..686c1bc1841dae6b1a0891b143388b982d6c9ea4 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/lpips.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/lpips.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ede5a84ac0ae0d9503a602265a129e1b7d03c66a Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/lpips.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/metric.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/metric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cce13a652cb1f8b4727f3556d9f969ed5fca9d03 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/metric.cpython-310.pyc differ diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/option.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/option.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..864a4e7b826e5c420c34203ee97fcb6d00d05773 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/option.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/utils.cpython-310.pyc b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1acd54542f861d4466261664fa2849b695053e90 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/cache/vgg.pth b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/lpips.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a30b875fb4aa39ccd8419759d2f841d62bbad6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/lpips.py @@ -0,0 +1,184 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = 
NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + input = input.float() + target = target.float() + + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + 
"VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/metric.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..53ddb27a96bab67975beef06ca6819e628208153 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/metric.py @@ -0,0 +1,51 @@ + +import numpy as np +from skimage.metrics import peak_signal_noise_ratio, structural_similarity + +def nmse(gt, pred): + """Compute Normalized Mean Squared Error (NMSE)""" + return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 + + +def psnr(gt, pred): + """Compute Peak Signal to Noise Ratio metric (PSNR)""" + return peak_signal_noise_ratio(gt, pred, data_range=gt.max()) + + +def ssim(gt, pred, maxval=None): + """Compute Structural Similarity Index Metric (SSIM)""" + maxval = gt.max() if maxval is None else maxval + + ssim = 0 + for slice_num in range(gt.shape[0]): + ssim = ssim + structural_similarity( + gt[slice_num], pred[slice_num], data_range=maxval + ) + + ssim = ssim / gt.shape[0] + + return ssim + + +class AverageMeter(object): + """Computes and stores the average and current value. + + Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 + """ + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.score = [] + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + self.score.append(val) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/option.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/option.py new file mode 100644 index 0000000000000000000000000000000000000000..16536f6ffa13a304005f947aeb76a7afe62e89fa --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/option.py @@ -0,0 +1,74 @@ +import argparse + +parser = argparse.ArgumentParser(description='MRI recon') +parser.add_argument('--root_path', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/selected_images/') +parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') +parser.add_argument('--low_field_SNR', type=int, default=0, help='SNR of the simulated low-field image') +parser.add_argument('--phase', type=str, default='train', help='Name of phase') +parser.add_argument('--gpu', type=str, default='0', help='GPU to use') +parser.add_argument('--exp', type=str, default='msl_model', help='model_name') +parser.add_argument('--max_iterations', type=int, default=100000, help='maximum epoch number to train') +parser.add_argument('--batch_size', type=int, default=8, help='batch_size per gpu') +parser.add_argument('--base_lr', type=float, default=0.0002, help='maximum epoch numaber to train') +parser.add_argument('--seed', type=int, default=1337, help='random seed') +parser.add_argument('--resume', type=str, default=None, help='resume') +parser.add_argument('--relation_consistency', type=str, default='False', help='regularize the consistency of feature relation') 
+parser.add_argument('--clip_grad', type=str, default='True', help='clip gradient of the network parameters') +# grad_accum_steps +parser.add_argument('--grad_accum_steps', type=int, default=1, help='gradient accumulation steps') + + + +parser.add_argument('--norm', type=str, default='False', help='Norm Layer between UNet and Transformer') +parser.add_argument('--input_normalize', type=str, default='mean_std', help='choose from [min_max, mean_std, divide]') + +parser.add_argument("--dist_url", default="63654") + +parser.add_argument('--scale', type=int, default=8, + help='super resolution scale') +parser.add_argument('--base_num_every_group', type=int, default=2, + help='super resolution scale') + + +parser.add_argument('--rgb_range', type=int, default=255, + help='maximum value of RGB') +parser.add_argument('--n_colors', type=int, default=3, + help='number of color channels to use') +parser.add_argument('--augment', action='store_true', + help='use data augmentation') +parser.add_argument('--fftloss', action='store_true', + help='use data augmentation') +parser.add_argument('--fftd', action='store_true', + help='use data augmentation') +parser.add_argument('--fftd_weight', type=float, default=0.1, + help='use data augmentation') +parser.add_argument('--fft_weight', type=float, default=0.01) + + +# Model specifications +parser.add_argument('--model', type=str, default='MYNET') +parser.add_argument('--act', type=str, default='PReLU') +parser.add_argument('--data_range', type=float, default=1) +parser.add_argument('--num_channels', type=int, default=1) +parser.add_argument('--num_features', type=int, default=64) + +parser.add_argument('--n_feats', type=int, default=64, + help='number of feature maps') +parser.add_argument('--res_scale', type=float, default=0.2, + help='residual scaling') + +parser.add_argument('--MASKTYPE', type=str, default='random') # "random" or "equispaced" +parser.add_argument('--CENTER_FRACTIONS', nargs='+', type=float) +parser.add_argument('--ACCELERATIONS', nargs='+', type=int) + +parser.add_argument('--num_timesteps', type=int, default=5) +parser.add_argument('--image_size', type=int, default=240) +parser.add_argument('--distortion_sigma', type=float, default=10/255) +parser.add_argument('--DEBUG', action='store_true') +parser.add_argument('--use_kspace', action='store_true') +parser.add_argument('--use_time_model', action='store_true') + +parser.add_argument('--test_tag', default=None) +parser.add_argument('--test_sample', default="Ksample", help="Ksample | ColdDiffusion | DDPM") + +args = parser.parse_args() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/toolbox_data.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/toolbox_data.py new file mode 100644 index 0000000000000000000000000000000000000000..3671422a9557b7f9bf1f430fc08afdfd4f2d6821 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/toolbox_data.py @@ -0,0 +1,39 @@ +from torchvision import transforms +from torch.utils.data import DataLoader + + + + +def get_dataset(name, args): + + batch_size = args.batch_size * len(args.gpu.split(',')) + + + if name == 'brats': + from dataloaders.BRATS_dataloader_new import Hybrid as BratsDataset + from dataloaders.BRATS_dataloader_new import RandomPadCrop, ToTensor, RandomFlip + + + db_train = BratsDataset(split='train', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([RandomPadCrop(), ToTensor()]), # RandomFlip(), + base_dir=args.root_path, input_normalize = args.input_normalize, + 
use_kspace=args.use_kspace) + + + db_test = BratsDataset(split='test', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([ToTensor()]), + base_dir=args.root_path, input_normalize = args.input_normalize, + use_kspace=args.use_kspace) + + + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + fixtrainloader = DataLoader(db_train, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + return trainloader, fixtrainloader, testloader + + + else: + raise NotImplementedError(f'Dataset {name} is not implemented.') + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/toolbox_image.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/toolbox_image.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/toolbox_network.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/toolbox_network.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e212548bb2dfca88701636cb2f732d9daaf008 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/toolbox_network.py @@ -0,0 +1,17 @@ + +from networks.mynet import TwoBranch +from networks_time.mynet import DiffTwoBranch + + +def get_network(args, ): + if args.use_time_model: + network = DiffTwoBranch(args) + else: + network = TwoBranch(args) + + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + return network + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a8c9361564ec52c1dab3fb970616c8edc0893c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/utils/utils.py @@ -0,0 +1,96 @@ +import torch +from torch import nn +import numpy as np + +def cc(img1, img2): + eps = torch.finfo(torch.float32).eps + """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" + N, C, _, _ = img1.shape + img1 = img1.reshape(N, C, -1) + img2 = img2.reshape(N, C, -1) + img1 = img1 - img1.mean(dim=-1, keepdim=True) + img2 = img2 - img2.mean(dim=-1, keepdim=True) + cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum( + img1 **2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1))) + cc = torch.clamp(cc, -1., 1.) 
+ return cc.mean() + + +def gradient_calllback(network): + for name, param in network.named_parameters(): + if param.grad is not None: + # print("Gradient of {}: {}".format(name, param.grad.abs().mean())) + + if param.grad.abs().mean() == 0: + print("Gradient of {} is 0".format(name)) + + else: + print("Gradient of {} is None".format(name)) + + +def bright(x, a,b): + # input datatype np.uint8 + x = np.array(x, dtype='float') + x = x/(b-a) - 255*a/(b-a) + x[x>255.0] = 255.0 + x[x<0.0] = 0.0 + x = x.astype(np.uint8) + return x + +def trunc(x): + # input datatype float + x[x>255.0] = 255.0 + x[x<0.0] = 0.0 + return x + + + + + +def gradient_calllback(network): + for name, param in network.named_parameters(): + if param.grad is not None: + if param.grad.abs().mean() == 0: + print("Gradient of {} is 0".format(name)) + + else: + print("Gradient of {} is None".format(name)) + +class Frequency_Loss(nn.Module): + def __init__(self): + super(Frequency_Loss, self).__init__() + self.cri = nn.L1Loss() + self.cri_sum = nn.L1Loss(reduction="sum") + + def forward(self, x, y, mask=None): + x = torch.fft.fftshift(torch.fft.fft2(x)) # rfft2 + y = torch.fft.fftshift(torch.fft.fft2(y)) + + + + # def apply_tofre(x_start, mask): + # # B, C, H, W = x_start.shape + # kspace = fftshift(fft2(x_start, norm=None, dim=(-2, -1)), dim=(-2, -1)) # Default: all dimensions + # mask = mask.to(kspace.device) + # return kspace, mask + + x_mag = torch.abs(x) + y_mag = torch.abs(y) + x_ang = torch.angle(x) + y_ang = torch.angle(y) + if isinstance(mask, type(None)): + return self.cri(x_mag,y_mag) + self.cri(x_ang, y_ang) + + k = (1 - mask.to(x.device)).detach() + # W = x.shape[-1] + # k = k[..., :W // 2 + 1] + k_total = torch.sum(k) + + x_mag = x_mag * k + y_mag = y_mag * k + x_ang = x_ang * k + y_ang = y_ang * k + + # Compute L1 loss between magnitudes + return self.cri_sum(x_mag, y_mag) / k_total + self.cri_sum(x_ang, y_ang) / k_total + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/x-bash/brats.sh b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/x-bash/brats.sh new file mode 100644 index 0000000000000000000000000000000000000000..d679b26ea80456b86e64c0510a5f1422183bd172 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/x-bash/brats.sh @@ -0,0 +1,95 @@ +# BraTS dataset, AF=4 +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion/FSMNet #-modify +cd /bask/projects/j/jiaoj-rep-learn/Hao/repo/Frequency-Diffusion/FSMNet #-modify + + +gamedrive=/media/cbtil3/74ec35fd-2452-4dcc-8d7d-3ba957e302c9 + +#4T folder: /media/cbtil3/9feaf350-913e-4def-8114-f03573c04364/hao +root_path_4x=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_4X/ +root_path_8x=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_8X/ + +root_path_4x=$gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/ +root_path_8x=$gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/ + + +num_timesteps=5 # 5 +max_iterations=200000 + + +# --batch_size 6 +# --base_lr 0.0001 +# mean_std | min_max + +python train_brats.py --root_path $root_path_4x\ + --gpu 1 --batch_size 2 --base_lr 1e-4 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_BraTS_4x --use_time_model --use_kspace \ + --num_timesteps $num_timesteps --max_iterations $max_iterations + + + +# BraTS dataset, AF=8 model/FSMNet_BraTS_8x_t5_kspace_time/60000 +python train_brats.py --root_path $root_path_8x \ + --gpu 1 --batch_size 
2 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_BraTS_8x --use_time_model --use_kspace \ + --num_timesteps $num_timesteps --max_iterations $max_iterations + + + +# Acceleration is which effective +python train_brats.py --root_path $root_path_8x \ + --gpu 1 --batch_size 2 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ + --exp FSMNet_BraTS_12x --use_time_model --use_kspace \ + --num_timesteps $num_timesteps --max_iterations $max_iterations + + + +# ---------------------------- +# Test +# ---------------------------- + +# Test +# BraTS dataset, AF=4 +# load weights from model/FSMNet_BraTS_4x_t5_kspace_time/best_checkpoint.pth +# test_sample # Ksample | ColdDiffusion | DDPM + +test_sample # Ksample | KsampleAR | ColdDiffusion | DDPM + + +num_timesteps=5 + +# default to load from the best_checkpoint.pth + +python test_brats.py --root_path $root_path_4x \ + --gpu 1 --batch_size 2 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --num_timesteps $num_timesteps \ + --input_normalize mean_std --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_BraTS_4x --phase test --use_time_model --use_kspace \ + --test_sample Ksample + + + + +python test_brats.py --root_path $root_path_8x \ + --gpu 1 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_BraTS_8x --phase test --use_time_model --use_kspace \ + --test_sample Ksample + + + +python test_brats.py --root_path $root_path_8x \ + --gpu 1 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ + --exp FSMNet_BraTS_12x --phase test --use_time_model --use_kspace \ + --test_sample Ksample + + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/x-bash/fastmri.sh b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/x-bash/fastmri.sh new file mode 100644 index 0000000000000000000000000000000000000000..4948949567c48d37696f7493a6ca51b3fb3a9348 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/x-bash/fastmri.sh @@ -0,0 +1,68 @@ +# fastMRI dataset, AF=4 +# BraTS dataset, AF=4 +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion/FSMNet #-modify +cd /bask/projects/j/jiaoj-rep-learn/Hao/repo/Frequency-Diffusion/FSMNet #-modify + +data_root=/home/v-qichen3/blob/qichen_blob/MRI_recon/data/fastmri + +num_timesteps=5 # 5 +max_iterations=200000 + + +python train_fastmri.py --root_path $data_root \ + --gpu 0 --batch_size 2 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --MRIDOWN 4X --MASKTYPE random \ + --input_normalize mean_std \ + --num_timesteps $num_timesteps --max_iterations $max_iterations \ + --image_size 320 --use_kspace --use_time_model + + + +# fastMRI dataset, AF=8 + +python train_fastmri.py --root_path $data_root \ + --gpu 1 --batch_size 2 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --MRIDOWN 8X --MASKTYPE equispaced \ + --num_timesteps $num_timesteps --max_iterations $max_iterations \ + --image_size 320 --use_kspace --use_time_model + + +python train_fastmri.py --root_path $data_root \ + --gpu 0 --batch_size 2 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ + --exp FSMNet_fastmri_12x --MRIDOWN 12X --MASKTYPE equispaced \ + --num_timesteps $num_timesteps --max_iterations $max_iterations 
\ + --image_size 320 --use_kspace --use_time_model + + + +# Test +#fastMRI dataset, AF=4 +model_4x=model/FSMNet_fastmri_4x/iter_100000.pth + +python test_fastmri.py --root_path $data_root \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --phase test --MRIDOWN 4X \ + --num_timesteps 5 --image_size 320 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + + +#fastMRI dataset, AF=8 + +python test_fastmri.py --root_path $data_root \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --phase test --MRIDOWN 8X \ + --num_timesteps $num_timesteps --image_size 320 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + + + +python test_fastmri.py --root_path $data_root \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ + --exp FSMNet_fastmri_12x --phase test --MRIDOWN 12X \ + --num_timesteps 5 --image_size 320 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + + +# FSMNet_fastmri_12x_t30_kspace_time + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/x-bash/m4raw.sh b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/x-bash/m4raw.sh new file mode 100644 index 0000000000000000000000000000000000000000..20748e5e58a75f72b0ab25842b8478f19a178d54 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/FSMNet/x-bash/m4raw.sh @@ -0,0 +1,68 @@ +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion/FSMNet +cd /bask/projects/j/jiaoj-rep-learn/Hao/repo/Frequency-Diffusion/FSMNet + +git pull + +data_root=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee +data_root=/media/cbtil3/74ec35fd-2452-4dcc-8d7d-3ba957e302c9/Datasets/medical/FrequencyDiffusion + + +num_timesteps=5 # 5 +max_iterations=100000 + + + +# random | equispaced + +# save to model/FSMNet_m4raw_4x_t5_new_kspace_time/, 27.8, need 30.4? 
+python train_m4raw.py --root_path $data_root \ + --gpu 0 --batch_size 6 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_m4raw_4x --MRIDOWN 4X --MASKTYPE random \ + --image_size 240 --use_kspace --use_time_model \ + --max_iterations $max_iterations --num_timesteps $num_timesteps --DEBUG + + + +# m4raw dataset, AF=8 +# 240 +num_timesteps=10 +python train_m4raw.py --root_path $data_root \ + --gpu 0 --batch_size 2 --base_lr 1e-5 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_m4raw_8x --MRIDOWN 8X --MASKTYPE equispaced \ + --image_size 320 --use_kspace --use_time_model \ + --max_iterations $max_iterations --num_timesteps $num_timesteps + + + +data_root=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee +python train_m4raw.py --root_path $data_root \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ + --exp FSMNet_m4raw_12x --MRIDOWN 12X --MASKTYPE equispaced \ + --num_timesteps 30 --image_size 240 --use_kspace --use_time_model + + +# ---------------- Test ---------------- + +python test_m4raw.py --root_path $data_root \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_m4raw_4x --phase test --MASKTYPE random --MRIDOWN 4X \ + --num_timesteps $num_timesteps --image_size 240 --use_kspace --use_time_model \ + --test_sample Ksample # --test_tag no_distortion ColdDiffusion DDPM Ksample + + +#m4raw dataset, AF=8 +python test_m4raw.py --root_path $data_root \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_m4raw_8x --phase test --MASKTYPE equispaced --MRIDOWN 8X \ + --num_timesteps 10 --image_size 320 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + + +python test_m4raw.py --root_path $data_root \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.03 --ACCELERATIONS 12 \ + --exp FSMNet_m4raw_12x --phase test --MASKTYPE equispaced --MRIDOWN 12X \ + --num_timesteps 5 --image_size 240 --use_kspace --use_time_model --test_sample Ksample # ColdDiffusion DDPM + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/README.md b/MRI_recon/new_code/Frequency-Diffusion-main/README.md new file mode 100644 index 0000000000000000000000000000000000000000..915cfe1e79d4db2f1439101a0da25dcd1a0a461b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/README.md @@ -0,0 +1,19 @@ +# Frequency-Diffusion + + +Run Knee FastMRI dataset, the test code is also in the end of the file +```bash + +cd FSMNet +bash bash/fastmri.sh + +``` + +Run Brain m4raw dataset, the test code is also in the end of the file +```bash + +cd FSMNet +bash bash/m4raw.sh + +``` + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/bash/adden/brain.sh b/MRI_recon/new_code/Frequency-Diffusion-main/bash/adden/brain.sh new file mode 100644 index 0000000000000000000000000000000000000000..c0e05b7a031a05ea07cebbb21270234cb9ffa5f7 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/bash/adden/brain.sh @@ -0,0 +1,82 @@ +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + +datapath=/home/hao/data/medical/Brain/ +# /gamedrive/Datasets/medical/Brain/ + +dataset=Brain +domain=BraTS-GLI-T1C # T1C 
+aux_modality=T1N # T1C, T1N, T2W, T2F +num_channels=1 + + +# T1: T1-weighted MRI; T1c: gadolinium-contrast-enhanced T1-weighted MRI; + + +diffusion_type=twobranch_kspace # Easy NaN + + +time_step=30 +image_size=240 #128 +sampling_routine=x0_step_down_fre # x0_step_down_fre # x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 1 # l2 | l1 | l2_l1, l1 is better + + +tag=new_norm #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=1 # specify the GPU ids +# fre_before_attn + l1 +train_bs=2 # 4 | 32 | 24 | 36 + + +save_folder=./results/${diffusion_type}_${sampling_routine} + + +datapath=/home/hao/data/medical/Brain/ +datapath=/gamedrive/Datasets/medical/Brain/brats/Processed/ + + +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --debug --mode $mode # --debug, --discrete + +# FSM Brain +/gamedrive/Datasets/medical/FrequencyDiffusion/image_100patients_4X + +BraTS20_Training_099_99_t1.png +BraTS20_Training_099_99_t2.png + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/bash/adden/fsm_brain.sh b/MRI_recon/new_code/Frequency-Diffusion-main/bash/adden/fsm_brain.sh new file mode 100644 index 0000000000000000000000000000000000000000..6382b1e4b75426e366d3c62cd25721c247a0d854 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/bash/adden/fsm_brain.sh @@ -0,0 +1,69 @@ +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + + +dataset=fsm_brain +num_channels=1 +diffusion_type=twobranch_kspace + + +datapath=/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X + + +time_step=25 +image_size=240 #128 +sampling_routine=x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 | l1 | l2_l1, l1 is better + + +tag=adden_brain #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=1 # specify the GPU ids +# fre_before_attn + l1 +train_bs=4 # 4 | 32 | 24 | 36 +accelerate_factor=8 + +save_folder=./results/${diffusion_type}_${sampling_routine} + + +normalizer="mean_std" +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid 
$deviceid \ + --kernel_std 0.15 --loss_type $loss_type --accelerate_factor $accelerate_factor\ + --mode $mode --normalizer $normalizer --debug # --debug, --discrete + +# FSM Brain + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/bash/adden/knee.sh b/MRI_recon/new_code/Frequency-Diffusion-main/bash/adden/knee.sh new file mode 100644 index 0000000000000000000000000000000000000000..27f30c851971f46d1a241a63da7c9583e0ccd5a7 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/bash/adden/knee.sh @@ -0,0 +1,68 @@ +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + + +dataset=fsm_brain +num_channels=1 +diffusion_type=twobranch_kspace + + +datapath=/gamedrive/Datasets/medical/FrequencyDiffusion/singlecoil_train_selected + + +time_step=25 +image_size=320 #128 +sampling_routine=x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 | l1 | l2_l1, l1 is better + + +tag=adden_brain #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=1 # specify the GPU ids +train_bs=4 # 4 | 32 | 24 | 36 +accelerate_factor=8 + +save_folder=./results/${diffusion_type}_${sampling_routine} + + +normalizer="mean_std" +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --accelerate_factor $accelerate_factor\ + --mode $mode --normalizer $normalizer --debug # --debug, --discrete + +# FSM Brain + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/bash/bask/brain.sh b/MRI_recon/new_code/Frequency-Diffusion-main/bash/bask/brain.sh new file mode 100644 index 0000000000000000000000000000000000000000..a085b08e428a3b1493ec5f890cc13a6d58a3a665 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/bash/bask/brain.sh @@ -0,0 +1,82 @@ +mamba activate diffmri +cd 
/home/cbtil3/hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + +datapath=/home/hao/data/medical/Brain/ +# /gamedrive/Datasets/medical/Brain/ + +dataset=Brain +domain=BraTS-GLI-T1C # T1C +aux_modality=T1N # T1C, T1N, T2W, T2F +num_channels=1 + + +# T1: T1-weighted MRI; T1c: gadolinium-contrast-enhanced T1-weighted MRI; + + +diffusion_type=twobranch_kspace # Easy NaN + + +time_step=30 +image_size=480 #128 +sampling_routine=x0_step_down_fre # x0_step_down_fre # x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 1 # l2 | l1 | l2_l1, l1 is better + + +tag=new_norm #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=1 # specify the GPU ids +# fre_before_attn + l1 +train_bs=2 # 4 | 32 | 24 | 36 + + +save_folder=./results/${diffusion_type}_${sampling_routine} + + +datapath=/home/hao/data/medical/Brain/ +datapath=/gamedrive/Datasets/medical/Brain/brats/Processed/ + + +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --debug --mode $mode # --debug, --discrete + +# FSM Brain +/gamedrive/Datasets/medical/FrequencyDiffusion/image_100patients_4X + +BraTS20_Training_099_99_t1.png +BraTS20_Training_099_99_t2.png + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/bash/bask/fsm_brain.sh b/MRI_recon/new_code/Frequency-Diffusion-main/bash/bask/fsm_brain.sh new file mode 100644 index 0000000000000000000000000000000000000000..f285065b12014ca64037683ec4afd95db79c85a8 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/bash/bask/fsm_brain.sh @@ -0,0 +1,68 @@ +mamba activate diffmri +cd /bask/projects/j/jiaoj-rep-learn/Hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + + +dataset=fsm_brain +num_channels=1 +diffusion_type=twobranch_kspace # Easy NaN +datapath=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_4X/ + + +time_step=30 +image_size=240 #128 +sampling_routine=x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 | l1 | l2_l1, l1 is better + + +tag=fsm_brain #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=0 # specify the GPU ids +# fre_before_attn + l1 +train_bs=4 # 4 | 32 | 24 | 36 + + +save_folder=./results/${diffusion_type}_${sampling_routine} +normalizer="mean_std" + +mode=train +example_frequency_img="/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_4X/BraTS20_Training_036_90_t2_4X_undermri.png" # 
some example img +example_frequency_img="" + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --normalizer $normalizer \ + --example_frequency_img $example_frequency_img --debug --mode $mode # --debug, --discrete + +# FSM Brain + +mode=test +checkpoint=results/52_twobranch_kspace_x0_step_down_fre_new_norm/model.pt + + +deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --domain $domain --aux_modality $aux_modality \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/bash/bask/knee.sh b/MRI_recon/new_code/Frequency-Diffusion-main/bash/bask/knee.sh new file mode 100644 index 0000000000000000000000000000000000000000..fd3c61236ab8da5543f76676e67c2a921e6f1d3d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/bash/bask/knee.sh @@ -0,0 +1,64 @@ +mamba activate diffmri +cd /bask/projects/j/jiaoj-rep-learn/Hao/repo/Frequency-Diffusion +# cd STEP1.AutoencoderModel2D +#git stash +git pull + + +dataset=fsm_knee +num_channels=1 +diffusion_type=twobranch_kspace # Easy NaN +datapath=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/singlecoil_train_selected + + +time_step=30 +image_size=320 #320 +sampling_routine=x0_step_down_fre # default | x0_step_down | x0_step_down_fre +loss_type=l1 # l2 | l1 | l2_l1, l1 is better + + +tag=fsm_knee #add_blur_transformer # x0_step_down | x0_step_down_fre + +deviceid=0 # specify the GPU ids +train_bs=4 # 4 | 32 | 24 | 36 + +normalizer="mean_std" +save_folder=./results/${diffusion_type}_${sampling_routine} + + +mode=train + +# Train +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --loss_type $loss_type --normalizer $normalizer --debug --mode $mode # --debug, --discrete + + + +mode=test +checkpoint=results/71_twobranch_kspace_x0_step_down_fre_new_loss/model.pt + + +#deviceid=1 # deviceid=1 + + +# test +python main.py --time_steps $time_step --train_steps 700000 \ + --save_folder $save_folder --tag $tag \ + --data_path $datapath --dataset $dataset \ + --sampling_routine $sampling_routine \ + --remove_time_embed --residual --image_size $image_size \ + --diffusion_type $diffusion_type --train_bs $train_bs \ + --num_channels $num_channels --deviceid $deviceid \ + --kernel_std 0.15 --load_path $checkpoint --debug --mode $mode + # --debug, --discrete + + + +# Knee +/gamedrive/Datasets/medical/Knee/fastMRI/ diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/BRATS_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/BRATS_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..bc433096d6058d9c5e7a259e56a6af2da385737c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/BRATS_dataloader.py @@ -0,0 +1,419 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset +from torchvision import transforms + + +from dataset.m4_utils.transform_albu import get_albu_transforms + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', image_size=(128,128), MRIDOWN='4X', \ + SNR=15, transform=None, input_normalize=None, debug=False): + + super().__init__() + self._base_dir = base_dir + '/' + # self._MRIDOWN = MRIDOWN + + + self.transforms = get_albu_transforms(split, image_size) + self.kspace_refine = "False" + self._MRIDOWN = "4X" + + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.t1_krecon_images, self.t2_krecon_images = [], [] + self.splits_path = base_dir.replace("image_100patients_4X", "cv_splits_100patients") + + if split=='train': + self.train_file = self.splits_path + '/train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + '/test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + if debug: + self.t1_images = self.t1_images[:10] + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + + if SNR == 0: + t1_under_path = image_path + + if self.kspace_refine == "False": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + elif self.kspace_refine == "True": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_krecon') + + if self.kspace_refine == "False": + t1_krecon_path = image_path + t2_krecon_path = image_path + + # if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + + else: + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 
'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + t1_krecon_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_krecon_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + self.t1_krecon_images.append(t1_krecon_path) + self.t2_krecon_images.append(t2_krecon_path) + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + def update_chunk(self): + pass + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index])) / 255.0 + + t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1_krecon = np.array(Image.open(self._base_dir + self.t1_krecon_images[index]))/255.0 + t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2_krecon = np.array(Image.open(self._base_dir + self.t2_krecon_images[index]))/255.0 + + t1 = np.asarray(t1, np.float32) + t2 = np.asarray(t2, np.float32) + + if self.input_normalize == "mean_std": + t1_in, t1_mean, t1_std = normalize_instance(t1_in, eps=1e-11) + t1 = normalize(t1, t1_mean, t1_std, eps=1e-11) + t2_in, t2_mean, t2_std = normalize_instance(t2_in, eps=1e-11) + t2 = normalize(t2, t2_mean, t2_std, eps=1e-11) + + t1_krecon = normalize(t1_krecon, t1_mean, t1_std, eps=1e-11) + t2_krecon = normalize(t2_krecon, t2_mean, t2_std, eps=1e-11) + + ### clamp input to ensure training stability. 
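+ ### After per-image mean/std normalization the intensities are roughly z-scores,
+ ### so values outside [-6, 6] are extreme outliers; clipping them presumably
+ ### keeps a few out-of-range voxels from destabilizing the loss.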
+ t1_in = np.clip(t1_in, -6, 6) + t1 = np.clip(t1, -6, 6) + t2_in = np.clip(t2_in, -6, 6) + t2 = np.clip(t2, -6, 6) + + t1_krecon = np.clip(t1_krecon, -6, 6) + t2_krecon = np.clip(t2_krecon, -6, 6) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + t1_in = (t1_in - t1_in.min())/(t1_in.max() - t1_in.min()) + t1 = (t1 - t1.min())/(t1.max() - t1.min()) + t2_in = (t2_in - t2_in.min())/(t2_in.max() - t2_in.min()) + t2 = (t2 - t2.min())/(t2.max() - t2.min()) + sample_stats = 0 + + t1_mean = 0 + t1_std = 1 + t2_mean = 0 + t2_std = 1 + + elif self.input_normalize == "divide": + sample_stats = 0 + + + sample = {'image_in': t1_in, + 'image': t1, + 'image_krecon': t1_krecon, + 'target_in': t2_in, + 'target': t2, + 'target_krecon': t2_krecon} + + # print("images shape:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + + # return sample, sample_stats + + t1_main = False + # t1 support t2 accelerate + + if t1_main: + img = t1 + img_mean = t1_mean + img_std = t1_std + + aux = t2 + aux_mean = t2_mean + aux_std = t2_std + else: + img = t2 + img_mean = t2_mean + img_std = t2_std + + aux = t1 + aux_mean = t1_mean + aux_std = t1_std + + + # print("img shape:", img.shape, aux.shape, img.max()) # 240, 240 + + data_dict = self.transforms(image=img, image2=aux) + img = data_dict['image'] + aux = data_dict['image2'] + + img = np.asarray(np.expand_dims(img, axis=0), np.float32) + aux = np.asarray(np.expand_dims(aux, axis=0), np.float32) + + data = {"img": img, "aux": aux, + "img_mean": np.float32(img_mean), "img_std": np.float32(img_std), + "aux_mean": np.float32(aux_mean), "aux_std": np.float32(aux_std), + } + + return data + + + +def add_gaussian_noise(img, mean=0, std=1): + noise = std * torch.randn_like(img) + mean + noisy_img = img + noise + return torch.clamp(noisy_img, 0, 1) + + + +class AddNoise(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + add_gauss_noise = transforms.GaussianBlur(kernel_size=5) + add_poiss_noise = transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)) + + add_noise = transforms.RandomApply([add_gauss_noise, add_poiss_noise], p=0.5) + + img_in = add_noise(img_in) + target_in = add_noise(target_in) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + + return sample + + + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_krecon = sample['image_krecon'] + target_krecon = sample['target_krecon'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + img_krecon = np.pad(img_krecon, pad_size, mode='reflect') + target_krecon = np.pad(target_krecon, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + img_krecon = img_krecon[ww:ww+crop_size, hh:hh+crop_size] + 
target_krecon = target_krecon[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'image_krecon': img_krecon, \ + 'target_in': target_in, 'target': target, 'target_krecon': target_krecon} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + +class RandomFlip(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + # horizontal flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 1) + img = cv2.flip(img, 1) + target_in = cv2.flip(target_in, 1) + target = cv2.flip(target, 1) + + # vertical flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 0) + img = cv2.flip(img, 0) + target_in = cv2.flip(target_in, 0) + target = cv2.flip(target, 0) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + + +class RandomRotate(object): + def __call__(self, sample, center=None, scale=1.0): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + degrees = [0, 90, 180, 270] + angle = random.choice(degrees) + + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + + img_in = cv2.warpAffine(img_in, matrix, (w, h)) + img = cv2.warpAffine(img, matrix, (w, h)) + target_in = cv2.warpAffine(target_in, matrix, (w, h)) + target = cv2.warpAffine(target, matrix, (w, h)) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + + image_krecon = sample['image_krecon'][:, :, None].transpose((2, 0, 1)) + target_krecon = sample['target_krecon'][:, :, None].transpose((2, 0, 1)) + + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + image_krecon = torch.from_numpy(image_krecon).float() + target_krecon = torch.from_numpy(target_krecon).float() + # print("img_in range:", img_in.max(), img_in.min()) + + 
return {'image_in': img_in, + 'image': img, + 'target_in': target_in, + 'target': target, + 'image_krecon': image_krecon, + 'target_krecon': target_krecon} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6736b58fe7f6c85492b3f8a78ad34bb1c49520 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/__init__.py @@ -0,0 +1,3 @@ +from .brain import BrainDataset +from .celeb import Dataset, Dataset_Aug1 + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/basic.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..da03388bb83aa0a49793ebc3cd7be4302f0d4fc5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/basic.py @@ -0,0 +1,460 @@ + +# Dataloader for abdominal images +import glob +import numpy as np +from .m4_utils import niftiio as nio +from .m4_utils import transform_utils as trans +from .m4_utils.abd_dataset_utils import get_normalize_op +from .m4_utils.transform_albu import get_albu_transforms, get_resize_transforms +import copy +import random, cv2, os +import torch.utils.data as torch_data +import math +import itertools +from pdb import set_trace +from multiprocessing import Process +import albumentations as A +from tqdm import tqdm + + +def get_basedir(data_dir): + return os.path.join(data_dir, "Abdominal") + + +class BasicDataset(torch_data.Dataset): + def __init__(self, fineSize, mode, transforms, base_dir, domains: list, aux_modality, pseudo = False, + idx_pct = [0.7, 0.1, 0.2], tile_z_dim = 3, extern_norm_fn = None, + LABEL_NAME=["bg", "fore"], debug=False, nclass=4, num_channels=3, + filter_non_labeled=False, use_diff_axis_view=False, chunksize=200): + """ + Args: + mode: 'train', 'val', 'test', 'test_all' + transforms: naive data augmentations used by default. Photometric transformations slightly better than those configured by Zhang et al. 
(bigaug) + idx_pct: train-val-test split for source domain + extern_norm_fn: feeding data normalization functions from external, only concerns CT-MR cross domain scenario + """ + super(BasicDataset, self).__init__() + + self.fineSize = fineSize + self.transforms = transforms + self.nclass = nclass + self.debug = debug + self.is_train = True if mode == 'train' else False + self.phase = mode + self.domains = domains + self.num_channels = num_channels + + # Modality + self.main_modality = domains[-1].split("-")[-1] + self.aux_modality = aux_modality.upper() + + print(f"=== Donmain: {domains}, Main modality: {self.main_modality}, Aux modality: {self.aux_modality}") + + self.pseudo = pseudo + self.all_label_names = LABEL_NAME + self.nclass = len(LABEL_NAME) + self.tile_z_dim = tile_z_dim + self._base_dir = base_dir + self.idx_pct = idx_pct + # self.albu_transform = get_albu_transforms((fineSize, fineSize)) + self.test_resizer = get_resize_transforms(fineSize) + self.fake_interpolate = True # True + self.use_diff_axis_view = use_diff_axis_view + self.filter_non_labeled = filter_non_labeled + self.input_window = 1 + + self.resizer = A.Compose([ + A.Resize(fineSize[0], fineSize[1], interpolation=cv2.INTER_NEAREST) + ], p=1.0, additional_targets={'image2': 'image', "mask2": "mask"}) + + self.img_pids = {} + for _domain in self.domains: # load file names + if "BraTS" in _domain: + self.img_pids[_domain] = sorted([ fid.split("-")[-2] for fid in + glob.glob(self._base_dir + "/" + _domain + "/img/*.nii.gz") ], + key = lambda x: int(x)) + + else: + self.img_pids[_domain] = sorted([fid.split("_")[-1].split(".nii.gz")[0] for fid in + glob.glob(self._base_dir + "/" + _domain + "/img/*.nii.gz")], + key=lambda x: int(x)) + + self.scan_ids = self.__get_scanids(mode, idx_pct) # train val test split in terms of patient ids + try: + print(f'For {self.phase} on {[_dm for _dm in self.domains]} using scan ids len = ' + \ + f'{[len(self.scan_ids[_dm]) for _dm in self.scan_ids.keys()]}') + except: + print("Errors of self.scan_ids") + print(self.scan_ids) + + + self.info_by_scan = None + self.sample_list = self.__search_samples(self.scan_ids) # image files names according to self.scan_ids + if self.is_train: + + self.pid_curr_load = self.scan_ids + elif mode == 'val': + self.pid_curr_load = self.scan_ids + elif mode == 'test': # Source domain test + self.pid_curr_load = self.scan_ids + elif mode == 'test_all': + # Choose this when being used as a target domain testing set. Liu et al. 
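+ # 'test_all' keeps every scan id of the domain (train + val + test,
+ # see __get_scanids), so the entire domain is evaluated.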
+ self.pid_curr_load = self.scan_ids + + if extern_norm_fn is None: + self.normalize_op = get_normalize_op(self.domains[0], [itm['img_fid'] for _, itm in + self.sample_list[self.domains[0]].items() ]) + print(f'{self.phase}_{self.domains[0]}: Using fold data statistics for normalization') + + else: + # assert len(self.domains) == 1, 'for now we only support one normalization function for the entire set' + self.normalize_op = extern_norm_fn + + + # load to memory + # self.sample_list All + self.actual_dataset = None + self.chunksize = chunksize if not debug else 3 + + self.chunk_id = 0 + self.chunk_pool, self.current_chunk = {}, {} + for _domain, item in self.sample_list.items(): + self.chunk_pool[_domain] = list(item.keys()) + + chunk, status = self.next_chunk(self.sample_list) + self.actual_dataset = self.__read_dataset(chunk, status) + self.size = len(self.actual_dataset) # 2D + + print("----- Set up dataset for", self.phase, "with chunksize=", chunksize) + + def update_chunk(self): + chunk, status = self.next_chunk(self.sample_list) + self.actual_dataset = self.__read_dataset(chunk, status) + + def __get_scanids(self, mode, idx_pct): + """ + index by domains given that we might need to load multi-domain data + idx_pct: [0.7 0.1 0.2] for train val test. with order te val tr + """ + tr_ids = {} + val_ids = {} + te_ids = {} + te_all_ids = {} + + for _domain in self.domains: + dset_size = len(self.img_pids[_domain]) + tr_size = round(dset_size * idx_pct[0]) + val_size = math.floor(dset_size * idx_pct[1]) + te_size = dset_size - tr_size - val_size + # print('te_size = ', te_size) + + te_ids[_domain] = self.img_pids[_domain][: te_size] + val_ids[_domain] = self.img_pids[_domain][te_size: te_size + val_size] + tr_ids[_domain] = self.img_pids[_domain][te_size + val_size: ] + te_all_ids[_domain] = list(itertools.chain(tr_ids[_domain], te_ids[_domain], val_ids[_domain] )) + + print(" self.phase = ", self.phase) + if self.phase == 'train': + return tr_ids + elif self.phase == 'val': + return val_ids + elif self.phase == 'test': + return te_ids + elif self.phase == 'test_all': + return te_all_ids + + def __search_samples(self, scan_ids): + """search for filenames for images and masks + """ + out_list = {} + for _domain, id_list in scan_ids.items(): + domain_dir = os.path.join(self._base_dir, _domain) + print("=== reading domains from:", domain_dir) + out_list[_domain] = {} + for curr_id in id_list: + curr_dict = {} + if "BraTS" in _domain: + + _img_fid = os.path.join(domain_dir, 'img', f'{_domain[:-4]}-{curr_id}-000.nii.gz') + if not self.pseudo: + _lb_fid = os.path.join(domain_dir, 'seg', f'{_domain[:-4]}-{curr_id}-000.nii.gz') + else: + _lb_fid = os.path.join(domain_dir, 'seg', f'{_domain[:-4]}-{curr_id}-000.nii.gz.npy') # npy + + _aux_fid = _img_fid.replace(self.main_modality, self.aux_modality) + + + + curr_dict["img_fid"] = _img_fid + curr_dict["lbs_fid"] = _lb_fid + curr_dict["aux_fid"] = _aux_fid + out_list[_domain][str(curr_id)] = curr_dict + + print("=== search sample num:", len(out_list)) + return out_list + + + def filter_with_label(self, img, lb, aux): + # H, W, C, filter zero + if self.phase == "train": + + filter = np.any(np.any(img, axis=0), axis=0) + img, lb, aux = img[..., filter], lb[..., filter], aux[..., filter] + + + if self.filter_non_labeled: + + if self.dataset_key == "knee": + filter2 = np.any(np.any(lb == 2, axis=0), axis=0) + filter4 = np.any(np.any(lb == 4, axis=0), axis=0) + filter = filter2 + filter4 + else: + filter = np.any(np.any(lb, axis=0), axis=0) + + 
filter_right = np.roll(filter, 3) + filter_left = np.roll(filter, -3) + filter = filter + filter_right + filter_left + filter = filter > 0 + + # HWC + img, lb, aux = img[..., filter], lb[..., filter], aux[..., filter] + + return img, lb, aux + + def __read_dataset(self, chunk, status): + """ + Read the dataset into memory + """ + + out_list = [] + self.info_by_scan = {} # meta data of each scan + glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset + for _domain, _curr_chunk in tqdm(chunk.items()): # .items() + domain_ids = 0 + if status[_domain] != 3: + print(f"==== UPDATE dataset for: {_domain} w/ status = {status[_domain]}") + + for scan_id in _curr_chunk: + domain_ids += 1 + if domain_ids > self.chunksize: + print(f"=== UPDATE finished") + break + + itm = self.sample_list[_domain][scan_id] + if scan_id not in self.pid_curr_load[_domain]: + continue + + # Keep the original dataset + if (status[_domain] == 0) or (status[_domain] == 2 and domain_ids <= self.chunksize // 2): + size = self.actual_dataset[glb_idx]['size'] + out_list.extend(self.actual_dataset[glb_idx: glb_idx + size]) # Original dataset + glb_idx += size + continue + + if (status[_domain] == 1 and domain_ids > self.chunksize // 2): + try: + size = self.actual_dataset[glb_idx]['size'] + out_list.extend(self.actual_dataset[glb_idx: glb_idx + size]) # Original dataset + glb_idx += size + continue + except: + print(f"=== Warning (domain_ids={domain_ids}) getting glb_idx={glb_idx} from actual_dataset length={len(self.actual_dataset)}") + # print(self.actual_dataset) + + img, _info = nio.read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out + self.info_by_scan[_domain + '_' + scan_id] = _info + + img_original = np.float32(img) + img = img_original.copy() + + aux = nio.read_nii_bysitk(itm["aux_fid"]) + aux_original = np.float32(aux) + aux = aux_original.copy() + + + # img, self.mean, self.std = self.normalize_op(img) + _, mean, std = self.normalize_op(img) + _, aux_mean, aux_std = self.normalize_op(aux) + + if not self.pseudo: + lb = nio.read_nii_bysitk(itm["lbs_fid"]) + else: + uncertainty_thr = 0.05 # 0.05 + lb_cache = np.load(itm["lbs_fid"], allow_pickle=True).item() + lb = lb_cache['pseudo'].cpu().numpy() # "pseudo": curr_pred, "score":curr_score , "uncertainty" + uncertainty = lb_cache['uncertainty'].cpu().numpy() # Z, C, H, W + uncertainty = np.float32(uncertainty) + + new_lb = np.zeros_like(lb) + for cls in range(self.nclass - 1): + un_mask = (uncertainty[:, cls+1] < uncertainty_thr ) * (cls+1) + new_lb[lb == (cls+1)] = un_mask[lb == (cls+1)] + + lb = new_lb + + lb_original = np.float32(lb) + lb = lb_original.copy() + + # -> H, W, C + img, lb, aux = map(lambda arr: np.transpose(arr, (1, 2, 0)), [img, lb, aux]) + assert img.shape[-1] == lb.shape[-1], f"ASSERT {img.shape} = {lb.shape}" + + # Resize: + if img.shape[1] != self.fineSize[1]: + # H, W, C + res = self.resizer(image=img, mask=lb, image2=aux) + img, lb, aux = res['image'], res['mask'], res['image2'] + + prt_cache = f" {_domain} stat ({domain_ids}/{len(_curr_chunk)}): shape={img.shape}, max={img.max()}, min={img.min()}" + + # Filter vacant slices + if self.phase == "train": + filter = np.any(np.any(img, axis=0), axis=0) + img, lb, aux = img[..., filter], lb[..., filter], aux[..., filter] + + img, lb, aux = self.filter_with_label(img, lb, aux) + + out_list, glb_idx = self.add_to_list(glb_idx, out_list, img, lb, + aux, mean, aux_mean, aux_std, std, _domain, + scan_id, itm["img_fid"]) + + if (domain_ids) % 
(len(_curr_chunk) // 2) == 0: + print(prt_cache + f", filtered shape={img.shape}, mask max={lb.max()}") + + + # Add various axis view !!! + if self.phase == "train" and self.use_diff_axis_view: + # C, W, H + img, lb, aux = img_original, lb_original, aux_original + # Resize: + if img.shape[1] != self.fineSize[1]: + res = self.resizer(image=img, mask=lb, image2=aux) # assume H, W, (C)<- + img, lb, aux = res['image'], res['mask'], res['image2'] + + img, lb, aux = self.filter_with_label(img, lb, aux) + + out_list, glb_idx = self.add_to_list(glb_idx, out_list, img, lb, + aux, mean, aux_mean, aux_std, std, _domain, + scan_id, itm["img_fid"]) + + del img, lb, aux, img_original, lb_original, aux_original + + del self.actual_dataset + return out_list + + def next_chunk(self, all_samples): + # 0 No update, 1 First half, 2 Second half, 3 All updates Chunk + status = {} + self.last_chunk = copy.deepcopy(self.current_chunk) + for _domain, _sample_list in tqdm(all_samples.items()): + # Default value + status[_domain] = 3 + + # Put all in - validation or small dataset + if ((not self.is_train) or len(_sample_list) < self.chunksize) and not self.debug: + self.current_chunk[_domain] = _sample_list + if _domain not in self.last_chunk: + status[_domain] = 3 # all + else: + status[_domain] = 0 # not updates + print("=== Put all data in for", _domain) + continue + + # chunksize + random.shuffle(self.chunk_pool[_domain]) + if _domain not in self.last_chunk: + status[_domain] = 3 + self.current_chunk[_domain] = self.chunk_pool[_domain][:self.chunksize] + self.chunk_pool[_domain] = self.chunk_pool[_domain][self.chunksize:] + + else: + status[_domain] = self.chunk_id//2 + 1 # 1, 2 + candidate = self.chunk_pool[_domain][:self.chunksize//2] + self.chunk_pool[_domain] = self.chunk_pool[_domain][self.chunksize //2:] + if status[_domain] == 1: + self.current_chunk[_domain][:self.chunksize // 2] = candidate + else: + self.current_chunk[_domain][self.chunksize // 2:] = candidate + + if _domain in self.last_chunk: + self.chunk_pool[_domain] = self.chunk_pool[_domain] + self.last_chunk[_domain] + + self.chunk_id += 1 + + return self.current_chunk, status + + + def add_to_list(self, glb_idx, out_list, img, lb, aux, mean, std, aux_mean, aux_std, _domain, scan_id, file_id): + # now start writing everthing in + c = 3 + + for ii in range(img.shape[-1]): + is_end = False + is_start = False + if ii == 0: + is_start = True + # write the beginning frame + if self.input_window == 3: + _img = img[..., 0: c].copy() + _img[..., 1] = _img[..., 0] + elif self.input_window == 1: + _img = img[..., 0: 0 + 1].copy() + + + elif ii < img.shape[-1] - 1: + if self.input_window == 3: + _img = img[..., ii -1: ii + 2].copy() + elif self.input_window == 1: + _img = img[..., ii: ii + 1].copy() + + else: + is_end = True + if self.input_window == 3: + _img = img[..., ii-2: ii + 1].copy() + _img[..., 0] = _img[..., 1] + elif self.input_window == 1: + _img = img[..., ii: ii+ 1].copy() + + _lb = lb[..., ii: ii + 1] + _aux = aux[..., ii: ii + 1] + + out_list.append( + {"img": _img, "lb":_lb, "aux":_aux, "size": img.shape[-1], + "mean":mean, "std":std, + "aux_mean": aux_mean, "aux_std": aux_std, + "is_start": is_start, "is_end": is_end, + "domain": _domain, "nframe": img.shape[-1], + "scan_id": _domain + "_" + scan_id, + "pid": scan_id, "file_id": file_id, "z_id":ii}) + glb_idx += 1 + + return out_list, glb_idx + + + def get_patch_from_img(self, img_H, img_L, img_L2, crop_size=[320, 320], zslice_dim=2): + # -------------------------------- + # randomly crop 
the patch + # -------------------------------- + + H, W, _ = img_H.shape + rnd_h = random.randint(0, max(0, H - crop_size[0])) + rnd_w = random.randint(0, max(0, W - crop_size[1])) + + # image = torch.index_select(image, 0, torch.tensor([1])) + if zslice_dim == 2: + patch_H = img_H[rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1], :] + patch_L = img_L[rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1], :] + patch_L2 = img_L2[rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1], :] + elif zslice_dim == 0: + patch_H = img_H[:, rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1]] + patch_L = img_L[:, rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1]] + patch_L2 = img_L2[:, rnd_h:rnd_h + crop_size[0], rnd_w:rnd_w + crop_size[1]] + + return patch_H, patch_L, patch_L2 + + + def __len__(self): + """ + copy-paste from basic naive dataset configuration + """ + return len(self.actual_dataset) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/brain.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/brain.py new file mode 100644 index 0000000000000000000000000000000000000000..b7958424e373bc61cbdfc52a0b4348d76cef7dd4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/brain.py @@ -0,0 +1,148 @@ +# Dataloader for abdominal images +import glob +import numpy as np +from .m4_utils import niftiio as nio +from .m4_utils import transform_utils as trans +from .m4_utils.abd_dataset_utils import get_normalize_op +from .m4_utils.transform_albu import get_albu_transforms, get_resize_transforms + +import torch +import os +from pdb import set_trace +from multiprocessing import Process +from .basic import BasicDataset + + +LABEL_NAME = ["bg", "NCR", "ED", "ET"] + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + +def normalize_instance(data, mean=None, std=None, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + if mean is None: + mean = data.mean() + if std is None: + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +class BrainDataset(BasicDataset): + def __init__(self, mode, base_dir, image_size, + nclass, domains, aux_modality, **kwargs): + """ + Args: + mode: 'train', 'val', 'test', 'test_all' + transforms: naive data augmentations used by default. Photometric transformations slightly better than those configured by Zhang et al. 
(bigaug) + idx_pct: train-val-test split for source domain + extern_norm_fn: feeding data normalization functions from external, only concerns CT-MR cross domain scenario + """ + self.dataset_key = "brain" + transforms = get_albu_transforms(mode, image_size) + if isinstance(domains, str): + domains = [domains] + + super(BrainDataset, self).__init__(image_size, mode, + transforms, + base_dir, + domains, aux_modality, + nclass=nclass, + LABEL_NAME=LABEL_NAME, + filter_non_labeled=True, + **kwargs) + + def hwc_to_chw(self,img): + img = np.float32(img) + img = np.transpose(img, (2, 0, 1)) # [C, H, W] + img = torch.from_numpy( img.copy() ) + return img + + def perform_trans(self, img, mask, aux): + + T = self.albu_transform if self.is_train else self.test_resizer + buffer = T(image = img, mask=mask, image2=aux) # [0 - 255] + img, mask, aux = buffer['image'], buffer['mask'], buffer['image2'] + if len(mask.shape) == 2: + mask = mask[..., None] + + # if self.is_train: + # img, mask, aux = self.get_patch_from_img(img, mask, aux, crop_size=self.crop_size) # 192 + + return img, mask, aux + + + def __getitem__(self, index): + index = index % len(self.actual_dataset) + curr_dict = self.actual_dataset[index] # numpy + + # ----------------------- Extract Slice ----------------------- + img, mask, aux = curr_dict["img"], curr_dict["lb"], curr_dict["aux"] # H, W, C, [0 - 255] + domain, pid = curr_dict["domain"], curr_dict["pid"] + mean, std = curr_dict['mean'], curr_dict['std'] + aux_mean, aux_std = curr_dict['aux_mean'], curr_dict['aux_std'] + # max, min = img.max(), img.min() + std = 1 if std < 1e-3 else std + + # img = (img - mean) / std + ### 对input image和target image都做(x-mean)/std的归一化操作 + img, img_mean, img_std = normalize_instance(img, eps=1e-6) # mean=mean, std=std, + aux, aux_mean, aux_std = normalize_instance(aux, eps=1e-6) # mean=aux_mean, std=aux_std, + + ### clamp input to ensure training stability. 
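+ ### After normalize_instance the intensities are expressed in standard deviations
+ ### (roughly zero-mean / unit-std), so clipping at +/-6 only removes rare extreme
+ ### voxels, e.g. a value 9 std above the mean is mapped to 6.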
+ img = np.clip(img, -6, 6) + aux = np.clip(aux, -6, 6) + + + mask = mask[..., 0] + img, mask, aux = self.perform_trans(img, mask, aux) + img, mask, aux = map(lambda arr: self.hwc_to_chw(arr), [img, mask, aux]) + + img = np.clip(img, -6, 6) + aux = np.clip(aux, -6, 6) + + if self.tile_z_dim > 1 and self.input_window == 1 and self.num_channels == 3 : + img = img.repeat( [ self.tile_z_dim, 1, 1] ) + assert img.ndimension() == 3 + + data = {"img": img, "lb": mask, "aux": aux, + "img_mean": img_mean, "img_std": img_std, + "aux_mean": aux_mean, "aux_std": aux_std, + "is_start": curr_dict["is_start"], + "is_end": curr_dict["is_end"], + "nframe": np.int32(curr_dict["nframe"]), + "scan_id": curr_dict["scan_id"], + "z_id": curr_dict["z_id"], + "file_id": curr_dict["file_id"] + } + + return data + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/celeb.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/celeb.py new file mode 100644 index 0000000000000000000000000000000000000000..43db1aac54d58ed7cae60033e72e811c3467cc16 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/celeb.py @@ -0,0 +1,61 @@ +from comet_ml import Experiment +import math + + +from torch.utils import data +from pathlib import Path +from torchvision import transforms +from PIL import Image + + + + +class Dataset_Aug1(data.Dataset): + def __init__(self, folder, image_size, exts = ['jpg', 'jpeg', 'png']): + super().__init__() + self.folder = folder + self.image_size = image_size + self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] + + self.transform = transforms.Compose([ + transforms.Resize((int(image_size*1.12), int(image_size*1.12))), + transforms.RandomCrop(image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Lambda(lambda t: (t * 2) - 1) + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + img = Image.open(path) + img = img.convert('RGB') + return self.transform(img) + + + +class Dataset(data.Dataset): + def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']): + super().__init__() + self.folder = folder + self.image_size = image_size + self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] + + self.transform = transforms.Compose([ + transforms.Resize((int(image_size*1.12), int(image_size*1.12))), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + transforms.Lambda(lambda t: (t * 2) - 1) + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + img = Image.open(path) + img = img.convert('RGB') + return self.transform(img) + \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/dicom_test.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/dicom_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ece18c9d53a7429805b18cab2c7b98273e5db9a6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/dicom_test.py @@ -0,0 +1,75 @@ +import pydicom +# pip install pydicom + + +def print_dicom_metadata(file_path): + # Read the DICOM file + dicom_data = pydicom.dcmread(file_path) + + # Print all metadata + for element in dicom_data: + # Retrieve the tag, name, and value + tag = element.tag + name = element.name + value = element.value + + # Handle different types of values + if isinstance(value, pydicom.multival.MultiValue): + # Join MultiValue elements into a single string + value = ", 
".join(str(v) for v in value) + + elif isinstance(value, bytes): + # Decode bytes if possible, or represent them as hex + try: + value = value.decode('utf-8') + except UnicodeDecodeError: + value = value.hex() + + # Print the tag, name, and processed value + print(f"{tag} {name}: {value}") + +# Path to your DICOM file +dicom_file_path = "/gamedrive/Datasets/medical/Knee/fastMRI/knee_mri_clinical_seq_batch2/FB_476595____FB,1899398684/study_63e96492/MR2_dd8eb0e8/00031_852687759fd1a2c1.dcm" +dicom_file_path = "/gamedrive/Datasets/medical/Knee/fastMRI/knee_mri_clinical_seq_batch2/FB_476595____FB,1899398684/study_63e96492/MR3_e6e4d154/00013_15898f7eff8d4655.dcm" + + +dicom_data = pydicom.dcmread(dicom_file_path) +# MR2_dd8eb0e8/ MR3_e6e4d154/ MR4_71dd8cd8/ MR5_9419dbc1/ + +# Print metadata +# Extract the Series Description +series_description = dicom_data.get((0x0008, 0x103E), "Series Description not found").value +print(f"Series Description: {series_description}") + +dicom_folder = "/gamedrive/Datasets/medical/Knee/fastMRI/knee_mri_clinical_seq_batch2/" + +dict = {"MR2": [], "MR3":[], "MR4":[], "MR5":[], "MR6":[], "MR7":[], "MR8":[], "MR9":[]} + + +count = 0 +import os, glob +for mrs in glob.glob(f"{dicom_folder}/*/*/*"): + mr_name = mrs.split("/")[-1].split("_")[0] + print(mr_name) + + for root, dirs, files in os.walk(mrs): + for file in files[:1]: + if file.endswith(".dcm"): # Check if the file has a .dcm extension + dicom_file_path = os.path.join(root, file) + + # Read the DICOM file + ds = pydicom.dcmread(dicom_file_path) + + # Extract the Series Description, if available + series_description = ds.get((0x0008, 0x103E), None).value + if series_description: + print(f"File: {file} | Series Description: {series_description}") + else: + print(f"File: {file} | Series Description not found") + dict[mr_name].append(series_description) + count +=1 + if count == 200: + break + +for key, value in dict.items(): + print(f"{key}: {value} ") \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fastmri.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..68e3b22b8a78ef2c314b44a29798bb78ded7e726 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fastmri.py @@ -0,0 +1,338 @@ +import csv +import os +import random +import xml.etree.ElementTree as etree +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import pathlib + +import h5py +import numpy as np +import torch +import yaml +from torch.utils.data import Dataset +# from .transforms import build_transforms +from matplotlib import pyplot as plt +from dataset.m4_utils.transform_albu import get_albu_transforms + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. 
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")): + """ + Data directory fetcher. + + This is a brute-force simple way to configure data directories for a + project. Simply overwrite the variables for `knee_path` and `brain_path` + and this function will retrieve the requested subsplit of the data for use. + + Args: + key (str): key to retrieve path from data_config_file. + data_config_file (pathlib.Path, + default=pathlib.Path("fastmri_dirs.yaml")): Default path config + file. + + Returns: + pathlib.Path: The path to the specified directory. + """ + if not data_config_file.is_file(): + default_config = dict( + knee_path="/home/jc3/Data/", + brain_path="/home/jc3/Data/", + ) + with open(data_config_file, "w") as f: + yaml.dump(default_config, f) + + raise ValueError(f"Please populate {data_config_file} with directory paths.") + + with open(data_config_file, "r") as f: + data_dir = yaml.safe_load(f)[key] + + data_dir = pathlib.Path(data_dir) + + if not data_dir.exists(): + raise ValueError(f"Path {data_dir} from {data_config_file} does not exist.") + + return data_dir + + +def et_query( + root: etree.Element, + qlist: Sequence[str], + namespace: str = "http://www.ismrm.org/ISMRMRD", +) -> str: + """ + ElementTree query function. + This can be used to query an xml document via ElementTree. It uses qlist + for nested queries. + Args: + root: Root of the xml to search through. + qlist: A list of strings for nested searches, e.g. ["Encoding", + "matrixSize"] + namespace: Optional; xml namespace to prepend query. + Returns: + The retrieved data as a string. + """ + s = "." 
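+ # Builds a namespaced descendant query, e.g. qlist=["encoding", "encodedSpace", "matrixSize"]
+ # yields ".//ismrmrd_namespace:encoding//ismrmrd_namespace:encodedSpace//ismrmrd_namespace:matrixSize".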
+ prefix = "ismrmrd_namespace" + + ns = {prefix: namespace} + + for el in qlist: + s = s + f"//{prefix}:{el}" + + value = root.find(s, ns) + if value is None: + raise RuntimeError("Element not found") + + return str(value.text) + +from dataset.m4_utils.transforms import build_transforms + +class SliceDataset(Dataset): + def __init__( + self, + root, + transform, + challenge, + input_normalize="mean_std", + image_size=(128, 128), + sample_rate=1, + mode='train', + debug=True, + ): + self.mode = mode + self.transforms = get_albu_transforms(mode, image_size) + self.input_normalize = input_normalize + + # challenge + if challenge not in ("singlecoil", "multicoil"): + raise ValueError('challenge should be either "singlecoil" or "multicoil"') + self.recons_key = ( + "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss" + ) + # transform + self.transform = transform + + self.other_transform = build_transforms("random", + [1], + [1], mode) + self.examples = [] + + self.cur_path = root + print("dataroot = ", root) + if self.mode == "test": + self.csv_file = "./dataset/knee_data_split/singlecoil_train_split_less.csv" + else: + self.csv_file = "./dataset/knee_data_split/singlecoil_" + self.mode + "_split_less.csv" + + with open(self.csv_file, 'r') as f: + reader = csv.reader(f) + + id = 0 + if debug: + reader = list(reader)[:10] + + for row in reader: + pd_metadata, pd_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[0] + '.h5')) + + pdfs_metadata, pdfs_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[1] + '.h5')) + + for slice_id in range(min(pd_num_slices, pdfs_num_slices)): + self.examples.append( + (os.path.join(self.cur_path, row[0] + '.h5'), os.path.join(self.cur_path, row[1] + '.h5') + , slice_id, pd_metadata, pdfs_metadata, id)) + id += 1 + + if sample_rate < 1: + random.shuffle(self.examples) + num_examples = round(len(self.examples) * sample_rate) + + self.examples = self.examples[0:num_examples] + + self.down_transform = None + + def update_chunk(self): + pass + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + + # read pd + pd_fname, pdfs_fname, slice, pd_metadata, pdfs_metadata, id = self.examples[i] + + with h5py.File(pd_fname, "r") as hf: + pd_kspace = hf["kspace"][slice] + pd_mask = np.asarray(hf["mask"]) if "mask" in hf else None + pd_target = hf[self.recons_key][slice] if self.recons_key in hf else None + attrs = dict(hf.attrs) + attrs.update(pd_metadata) + + if self.other_transform is None: + pd_sample = (pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + else: + pd_sample = self.other_transform(pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + + with h5py.File(pdfs_fname, "r") as hf: + pdfs_kspace = hf["kspace"][slice] + pdfs_mask = np.asarray(hf["mask"]) if "mask" in hf else None + pdfs_target = hf[self.recons_key][slice] if self.recons_key in hf else None + attrs = dict(hf.attrs) + attrs.update(pdfs_metadata) + + + if self.other_transform is None: + pdfs_sample = (pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + else: + pdfs_sample = self.other_transform(pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + + + # input size = 1.1693149e-05 7.2921634e-06 7.177928e-05 3.3911466e-08 + # print("input size = ", pdfs_target.mean(), pdfs_target.std(), pdfs_target.max(), pdfs_target.min()) + pdfs_mean = pdfs_sample[2] + pdfs_std = pdfs_sample[3] + pd_mean = pd_sample[2] + pd_std = pd_sample[3] + + pdfs_target = pdfs_sample[1].numpy() + pd_target = 
pd_sample[1].numpy() + + # print("pdfs_target:", pdfs_target.shape, pdfs_target.max(), pdfs_target.min()) + + + # print("pdf=", pdfs_target.shape, pdfs_target.max(), pdfs_target.min()) + # if self.input_normalize == "mean_std": + # + # # print("std:", pdfs_sample[3]) + # # print("mean:", pdfs_sample[2]) + # + # pdfs_target, pdfs_mean, pdfs_std = normalize_instance(pdfs_target, eps=1e-11) + # pd_target, pd_mean, pd_std = normalize_instance(pd_target, eps=1e-11) + # + # elif self.input_normalize == "min_max": + # pdfs_target = (pdfs_target - pdfs_target.min()) / (pdfs_target.max() - pdfs_target.min()) + # pd_target = (pd_target - pd_target.min()) / (pd_target.max() - pd_target.min()) + # pdfs_mean = 0 + # pdfs_std = 1 + # pd_mean = 0 + # pd_std = 1 + # else: + # raise ValueError(f"Unrecognized input normalization: {self.input_normalize}") + + + # return (pd_sample, pdfs_sample, id) + pdfs_main = True # PDWI as the auxiliary and FS-PDWI as the target + + if pdfs_main: + img = pdfs_target + aux = pd_target + img_mean = pdfs_mean + img_std = pdfs_std + aux_mean = pd_mean + aux_std = pd_std + else: + img = pd_target + aux = pdfs_target + img_mean = pd_mean + img_std = pd_std + aux_mean = pdfs_mean + aux_std = pdfs_std + + + data_dict = self.transforms(image=img, image2=aux) + img = data_dict['image'] + aux = data_dict['image2'] + + + # print("===img===:", img.shape, aux.shape, img.max(), img.min()) # (320, 320) (320, 320) + # print("===img_mean===:", img_mean, aux_mean) # 0.0 0.0 + + img = np.expand_dims(img, axis=0) + aux = np.expand_dims(aux, axis=0) + + data = {"img": img, "aux": aux, + "img_mean": img_mean, "img_std": img_std, + "aux_mean": aux_mean, "aux_std": aux_std, + } + return data + + + def _retrieve_metadata(self, fname): + with h5py.File(fname, "r") as hf: + et_root = etree.fromstring(hf["ismrmrd_header"][()]) + + enc = ["encoding", "encodedSpace", "matrixSize"] + enc_size = ( + int(et_query(et_root, enc + ["x"])), + int(et_query(et_root, enc + ["y"])), + int(et_query(et_root, enc + ["z"])), + ) + rec = ["encoding", "reconSpace", "matrixSize"] + recon_size = ( + int(et_query(et_root, rec + ["x"])), + int(et_query(et_root, rec + ["y"])), + int(et_query(et_root, rec + ["z"])), + ) + + lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"] + enc_limits_center = int(et_query(et_root, lims + ["center"])) + enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1 + + padding_left = enc_size[1] // 2 - enc_limits_center + padding_right = padding_left + enc_limits_max + + num_slices = hf["kspace"].shape[0] + + metadata = { + "padding_left": padding_left, + "padding_right": padding_right, + "encoding_size": enc_size, + "recon_size": recon_size, + } + + return metadata, num_slices + + +def build_dataset( mode='train', image_size=128, sample_rate=1): + assert mode in ['train', 'val', 'test'], 'unknown mode' + # transforms = build_transforms(args, mode) + + return SliceDataset(os.path.join(args.root_path, 'singlecoil_' + mode), image_size, 'singlecoil', sample_rate=sample_rate, mode=mode) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fsm_dataloaders/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fsm_dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fsm_dataloaders/hybrid_sparse.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fsm_dataloaders/hybrid_sparse.py new file mode 100644 index 
0000000000000000000000000000000000000000..42e4a7e33c2204c13a1c4509897baf19e1fb07f1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fsm_dataloaders/hybrid_sparse.py @@ -0,0 +1,156 @@ +from __future__ import print_function, division +import numpy as np +from glob import glob +import random +from skimage import transform + +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', transform=None): + + super().__init__() + self._base_dir = base_dir + self.im_ids = [] + self.images = [] + self.gts = [] + + if split=='train': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir+"/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + + elif split=='test': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir + "/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + self.transform = transform + + assert (len(self.images) == len(self.gts)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.images))) + + def __len__(self): + return len(self.images) + + + def __getitem__(self, index): + img_in, img, target_in, target= self._make_img_gt_point_pair(index) + sample = {'image_in': img_in, 'image':img, 'target_in': target_in, 'target': target} + # print("image in:", img_in.shape) + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + + # the default setting (i.e., rawdata.npz) is 4X64P + dd = np.load(self.images[index].replace('.png', '_raw_4X64P.npz')) + # print("images range:", dd['fbp'].max(), dd['ct'].max(), dd['under_t1'].max(), dd['t1'].max()) + _img_in = dd['fbp'] + _img_in[_img_in>0.6]=0.6 + _img_in = _img_in/0.6 + + _img = dd['ct'] + _img =(_img/1000*0.192+0.192) + _img[_img<0.0]=0.0 + _img[_img>0.6]=0.6 + _img = _img/0.6 + + _target_in = dd['under_t1'] + _target = dd['t1'] + + return _img_in, _img, _target_in, _target + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 400, 400 + crop_size = 384 + pad_size = (400-384)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + 
target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fsm_dataloaders/kspace_subsample.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fsm_dataloaders/kspace_subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5b5694d8fee8b35ba8394fae98fe2d3aa25759 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/fsm_dataloaders/kspace_subsample.py @@ -0,0 +1,287 @@ +""" +2023/10/16, +preprocess kspace data with the undersampling mask in the fastMRI project. 
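+
+The flow implemented below: build a sampling mask (random or equispaced), FFT the
+image to k-space (with fftshift), zero out the unsampled columns, optionally add
+complex Gaussian noise at a target SNR, then inverse-FFT back to a magnitude image.
+
+Minimal usage sketch (assuming `slice_240` is a 240x240 numpy slice, the size the
+mask shape is hard-coded to below):
+
+    noisy_kspace, noisy_mri, kspace, mri, mask = undersample_mri(slice_240, "4X", _SNR=15)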
+""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + masked_spectrum = spectrum * mask[None, :, :, None] + return spectrum, masked_spectrum + + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2), norm='ortho') + + return image + + +def add_gaussian_noise(kspace, snr): + ### 根据SNR确定noise的放大比例 + num_pixels = kspace.shape[0]*kspace.shape[1]*kspace.shape[2]*kspace.shape[3] + psr = torch.sum(torch.abs(kspace.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(kspace.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(kspace.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(kspace.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noisy_kspace = kspace + noise + + return noisy_kspace + + +def mri_fft(raw_mri, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + + if _SNR > 0: + noisy_kspace = add_gaussian_noise(kspace, _SNR) + else: + noisy_kspace = kspace + + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1) + + + +def undersample_mri(raw_mri, _MRIDOWN, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] # [1, 240] + # print("mask:", mask.shape) + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + ### under-sample the kspace data. + kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) + ### add low-field noise to the kspace data. + if _SNR > 0: + noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) + else: + noisy_kspace = masked_kspace + + ### conver the corrupted kspace data back to noisy MRI image. 
+ noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
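+
+ Worked example (center_fraction=0.08, acceleration=4, 320 k-space columns):
+ num_low_freqs = round(320 * 0.08) = 26 center columns are always kept, and the
+ remaining columns are kept with probability
+ prob = (320 / 4 - 26) / (320 - 26) ≈ 0.18, so on average 320 / 4 = 80 columns survive.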
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/h5_test.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/h5_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2eeb09503b238a43f94ae98589dd9f7152fc65 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/h5_test.py @@ -0,0 +1,16 @@ +import h5py, glob +import xml.etree.ElementTree as ET + + +h5_path = "/Users/haochen/Downloads/singlecoil_test/" # file1000056.h5 + +for h5 in glob.glob(h5_path + "*.h5"): + with h5py.File(h5, "r") as hf: + print("Keys: %s" % hf.keys()) + print("Attrs: %s" % hf.attrs.items()) + print("kspace shape:", hf['kspace'].shape) + # print("ismrmrd_header shape:", hf['ismrmrd_header'].shape) + for key, value in hf.attrs.items(): + print(f" {key}: {value}") + + print() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/knee_data_split/singlecoil_train_split_less.csv b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/knee_data_split/singlecoil_train_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..d85707318750900b14a6e7100541242a60b7a310 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/knee_data_split/singlecoil_train_split_less.csv @@ -0,0 +1,227 @@ +file1000685,file1000568,0.301723929779229 +file1002273,file1000481,0.302226224199571 +file1000472,file1000142,0.304272730770318 +file1002186,file1000863,0.304812175768496 +file1002385,file1002518,0.305357274240413 +file1000981,file1000129,0.305533361411383 +file1001320,file1001948,0.306821514316368 +file1000633,file1002243,0.306892354331709 +file1001872,file1001294,0.308345907393103 +file1001474,file1001830,0.310481695157561 +file1001005,file1001283,0.310497722435023 +file1001690,file1001519,0.310709448786299 +file1002469,file1001811,0.31193137253455 +file1000914,file1000242,0.31237190359308 +file1002284,file1002012,0.315366393843169 +file1001721,file1001328,0.31735122361847 +file1000807,file1002334,0.320096908959039 +file1001944,file1002335,0.320272061156991 +file1002090,file1002431,0.320351887633851 +file1000499,file1002063,0.320786426659383 +file1001362,file1000509,0.32175341740359 +file1001421,file1000597,0.324291432700032 +file1000349,file1000321,0.324545110048573 +file1002123,file1001235,0.327142348994532 +file1001867,file1002086,0.328624781732941 +file1001007,file1001027,0.330759860300298 +file1001915,file1000088,0.331499371283099 +file1001661,file1000313,0.331905252950291 +file1000383,file1000307,0.339998107225229 
+file1000116,file1000632,0.34069458535013 +file1002303,file1000173,0.343821267871409 +file1000306,file1001277,0.344751178043605 +file1000003,file1001922,0.346138116633394 +file1000109,file1000143,0.347632265547478 +file1001999,file1000115,0.348248659775587 +file1000089,file1000326,0.348964657514049 +file1001205,file1002232,0.349375610862454 +file1000557,file1000619,0.351305005151048 +file1001823,file1000778,0.352076809462453 +file1000806,file1001130,0.352659078122633 +file1000365,file1000351,0.352772816610486 +file1002374,file1001778,0.352974481603711 +file1002516,file1001910,0.359896103026675 +file1001200,file1000931,0.360070003966827 +file1001479,file1000952,0.360424533696936 +file1000850,file1001942,0.362632797518558 +file1001426,file1002143,0.363271909822866 +file1001304,file1001333,0.36404737582222 +file1000390,file1000518,0.364744579516818 +file1000830,file1002096,0.365897427529429 +file1000794,file1001856,0.365973692948894 +file1001266,file1001327,0.366395851089761 +file1001692,file1002352,0.36655953875445 +file1001564,file1001024,0.367284385415205 +file1001861,file1002050,0.36783497787384 +file1002066,file1002361,0.367964419694875 +file1001613,file1002087,0.368231014746024 +file1001931,file1000220,0.368847112914793 +file1000339,file1000554,0.370123905662701 +file1000754,file1002208,0.37031588493778 +file1001067,file1001956,0.371313060558732 +file1000101,file1001053,0.372141932838775 +file1002520,file1002409,0.372501194473693 +file1001459,file1001615,0.373295536945146 +file1001673,file1000508,0.376416667681519 +file1002201,file1001228,0.376680033570078 +file1000058,file1002449,0.376927627737029 +file1001748,file1001042,0.378067114701689 +file1001941,file1000376,0.37841176147662 +file1000801,file1002545,0.378423759459738 +file1000010,file1000535,0.38111194591455 +file1000882,file1002154,0.382223600234592 +file1001694,file1001297,0.382545161354354 +file1001992,file1002456,0.382664563820782 +file1001666,file1001773,0.382892588770697 +file1001629,file1002514,0.383417073960824 +file1002113,file1000738,0.385439884728523 +file1002221,file1000569,0.385903801966773 +file1002296,file1002117,0.387319754665673 +file1000693,file1001945,0.387855926202209 +file1001410,file1000223,0.391284037867147 +file1002071,file1001425,0.391497653794399 +file1002325,file1001259,0.391913965917762 +file1002430,file1001969,0.392256443856501 +file1002462,file1000708,0.393161981208355 +file1002358,file1001888,0.39427809496515 +file1000485,file1000753,0.395316199436001 +file1002357,file1001973,0.39564210237905 +file1002130,file1002041,0.395978941103639 +file1002569,file1000097,0.397496127623486 +file1002264,file1000148,0.397630184088734 +file1002381,file1001401,0.398105992102355 +file1000289,file1000585,0.399527637723015 +file1002368,file1001723,0.400243022234875 +file1002342,file1001319,0.400431803928825 +file1002170,file1001226,0.400632448147846 +file1001385,file1001758,0.400855988878681 +file1001732,file1002541,0.40091828863264 +file1001102,file1000762,0.400923140595936 +file1001470,file1000181,0.401353492516182 +file1000400,file1000884,0.401562860630016 +file1002293,file1002523,0.401800994807451 +file1000728,file1001654,0.402763341041675 +file1000582,file1001491,0.403451830806034 +file1000586,file1001521,0.403648293267187 +file1002287,file1001770,0.405194821414496 +file1000371,file1000159,0.405999000381268 +file1002356,file1002064,0.406519210876811 +file1000324,file1000590,0.407593694425997 +file1001622,file1001710,0.40759525378577 +file1002037,file1000403,0.407814136488744 
+file1002444,file1000743,0.40943197761463 +file1001175,file1002088,0.410423663035312 +file1001391,file1000540,0.410854355646853 +file1002133,file1001186,0.411248429534111 +file1001229,file1001630,0.411355571792039 +file1002283,file1000402,0.411836769927671 +file1000627,file1000161,0.412089060388579 +file1001701,file1001402,0.412854774524637 +file1000795,file1000452,0.413448916432685 +file1000354,file1000947,0.41459642292987 +file1002043,file1002505,0.414863932355455 +file1001285,file1001113,0.418183757940871 +file1000170,file1001832,0.419441549204313 +file1002399,file1001500,0.419905873946513 +file1002439,file1000177,0.42054051043224 +file1001656,file1001217,0.420597020703942 +file1000296,file1000065,0.420845042251081 +file1000626,file1001623,0.42087934790355 +file1001767,file1000760,0.422315537515139 +file1000467,file1001246,0.422371268999111 +file1001033,file1000611,0.42425275873442 +file1002304,file1000221,0.425602179771197 +file1001737,file1001141,0.425716789218234 +file1001565,file1000559,0.426158561043574 +file1000249,file1000643,0.426541100077021 +file1002014,file1001109,0.426587840438723 +file1002006,file1000790,0.427829459781438 +file1000193,file1000750,0.428103808477214 +file1001993,file1001110,0.428186367615143 +file1002094,file1001814,0.428868578868176 +file1000098,file1001420,0.428968675677784 +file1000336,file1000211,0.430347427208789 +file1001498,file1002568,0.43204475404071 +file1001671,file1001106,0.432215802861284 +file1000426,file1002386,0.43283446816702 +file1001520,file1002481,0.434867670495723 +file1002189,file1001432,0.434924370194975 +file1001390,file1002554,0.435313848731387 +file1002166,file1001982,0.435387512979012 +file1001120,file1001006,0.435594761785839 +file1000149,file1001985,0.436289528591294 +file1001632,file1001008,0.436682374331417 +file1002567,file1001155,0.437221000601772 +file1000434,file1002195,0.438098100114814 +file1002532,file1001048,0.438500899539101 +file1001605,file1000927,0.438686659342641 +file1000479,file1000120,0.439587267995034 +file1002473,file1001388,0.439594997597548 +file1001108,file1002228,0.440528754793898 +file1002099,file1002056,0.440776843467602 +file1000191,file1002127,0.441114509542672 +file1000875,file1002494,0.441378135507993 +file1002161,file1000002,0.441912476744187 +file1002269,file1001220,0.442742296865228 +file1001295,file1001355,0.4435162405589 +file1001659,file1001023,0.444686151316673 +file1001857,file1001378,0.447500830900898 +file1001183,file1001370,0.447782748040587 +file1000428,file1000859,0.448328910257083 +file1000588,file1002227,0.448650488897259 +file1001098,file1000486,0.448862467740607 +file1001288,file1000408,0.450363676957042 +file1002097,file1001210,0.451126832474666 +file1000216,file1001082,0.451550143520946 +file1001746,file1001642,0.451781042569196 +file1002388,file1000204,0.451940333555972 +file1000021,file1000560,0.452234621797968 +file1000489,file1001545,0.452796032302523 +file1001116,file1000883,0.453096911915119 +file1001372,file1000561,0.45532542913335 +file1001276,file1000424,0.45534174289324 +file1000974,file1002098,0.455371894001872 +file1002566,file1002044,0.455937677517583 +file1000262,file1002046,0.456056330767294 +file1001619,file1001342,0.456559091350965 +file1000045,file1001616,0.457599407743834 +file1001468,file1002115,0.458095965024278 +file1001061,file1000233,0.460561351667266 +file1000558,file1000100,0.461094222462111 +file1000605,file1000691,0.461429521647285 +file1000640,file1000384,0.463383466503099 +file1000410,file1001358,0.463452482427773 
+file1000851,file1001014,0.463558384057952 +file1001092,file1000138,0.463591264436099 +file1000061,file1002049,0.465778207162619 +file1001206,file1000983,0.466701211830884 +file1000256,file1000475,0.466865377968187 +file1002434,file1001387,0.467154181996099 +file1001036,file1000210,0.470404279499276 +file1001540,file1001860,0.472822271037545 +file1001244,file1001154,0.475076170733515 +file1000131,file1001526,0.475459563440874 +file1000180,file1002045,0.476814451110009 +file1001837,file1000637,0.478851985878026 +file1002425,file1001891,0.481451070031007 +file1001056,file1000682,0.482320170742015 +file1002276,file1000777,0.483452141843029 +file1001139,file1002544,0.487462418948035 +file1000548,file1001257,0.488098081542811 +file1000188,file1001286,0.488423105111001 +file1001879,file1000999,0.488449105381724 +file1001062,file1000231,0.48930683373911 +file1000040,file1001873,0.492070802214623 +file1002286,file1000066,0.493213986773381 +file1002474,file1002563,0.501584439120211 +file1000967,file1000563,0.502066261411662 +file1001307,file1002048,0.50460435259807 +file1000483,file1001699,0.511819026566198 +file1001528,file1000285,0.512629017841038 +file1001742,file1002371,0.513805213204644 +file1002397,file1000592,0.515406473057 +file1000069,file1000510,0.528220553613126 +file1001087,file1001300,0.536510449049583 +file1001991,file1000836,0.538145797125916 +file1001382,file1001806,0.538539506621535 +file1000111,file1001189,0.557690760784602 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/knee_data_split/singlecoil_val_split_less.csv b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/knee_data_split/singlecoil_val_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..a1cbac5537562063359f4ac3e0985de51cb989b2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/knee_data_split/singlecoil_val_split_less.csv @@ -0,0 +1,45 @@ +file1000323,file1002538,0.30754967523156 +file1001458,file1001566,0.310512744537048 +file1000885,file1001059,0.318226346221521 +file1000464,file1000196,0.321465466968232 +file1000314,file1000178,0.327505552363568 +file1001163,file1001289,0.328954963947692 +file1000033,file1001191,0.330925609207301 +file1000976,file1000990,0.344036229323198 +file1001930,file1001834,0.345994076497818 +file1002546,file1001344,0.351762252794677 +file1000277,file1001429,0.353297786572139 +file1001893,file1001262,0.358064285890878 +file1000926,file1002067,0.360639004205491 +file1001650,file1002002,0.362186928073579 +file1001184,file1001655,0.362592305723707 +file1001497,file1001338,0.365599407221502 +file1001202,file1001365,0.3844323497275 +file1001126,file1002340,0.388929627976346 +file1001339,file1000291,0.391300537691403 +file1002187,file1001862,0.39883786878841 +file1000041,file1000591,0.39896683485823 +file1001064,file1001850,0.399687813966601 +file1001331,file1002214,0.400340820924839 +file1000831,file1000528,0.403582747590964 +file1000769,file1000538,0.405298051020298 +file1000182,file1001968,0.407646172205036 +file1002382,file1001651,0.410749052045234 +file1000660,file1000476,0.415423894745454 +file1002570,file1001726,0.424622351472032 +file1001585,file1000858,0.426738511964108 +file1000190,file1000593,0.428080574167047 +file1001170,file1001090,0.429987089825525 +file1002252,file1001440,0.432038842370013 +file1000697,file1001144,0.432558506761396 +file1001077,file1000000,0.441922503777368 +file1001381,file1001119,0.455418270809002 +file1001759,file1001851,0.460824505737749 +file1000635,file1002389,0.465674267492171 
+file1001668,file1001689,0.467330511330772 +file1001221,file1000818,0.469630000354232 +file1001298,file1002145,0.473526387887779 +file1001763,file1001938,0.47398893150184 +file1001444,file1000942,0.48507438696692 +file1000735,file1002007,0.496530240691134 +file1000477,file1000280,0.528508000547834 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/kspace.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/kspace.py new file mode 100644 index 0000000000000000000000000000000000000000..dc79b77ed1e78f86ba46d55364d84cfa449060d0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/kspace.py @@ -0,0 +1,34 @@ +import torch +from torch import nn +import os +import cv2 +import gc +import numpy as np +from scipy.io import * +from scipy.fftpack import * + + + +# Fourier Transform +def fft_map(x): + fft_x = torch.fft.fftn(x) + fft_x_real = fft_x.real + fft_x_imag = fft_x.imag + + return fft_x_real, fft_x_imag + + +def undersample_kspace(x, mask, is_noise, noise_level, noise_var): + + fft = fft2(x) + fft = fftshift(fft) + fft = fft * mask + + if is_noise: + raise NotImplementedError + fft = fft + generate_gaussian_noise(fft, noise_level, noise_var) + + fft = ifftshift(fft) + x = ifft2(fft) + + return x \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/kspace_subsample.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/kspace_subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..2634efacb70f129d616c385d17a3c8577ee9f9d4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/kspace_subsample.py @@ -0,0 +1,379 @@ +""" +2023/10/16, +preprocess kspace data with the undersampling mask in the fastMRI project. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + masked_spectrum = spectrum * mask[None, :, :, None] + return spectrum, masked_spectrum + + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2), norm='ortho') + # image = torch.fft.fftshift(image, dim=[1, 2]) + + return image + + +def add_gaussian_noise(kspace, snr): + ### 根据SNR确定noise的放大比例 + num_pixels = kspace.shape[0]*kspace.shape[1]*kspace.shape[2]*kspace.shape[3] + psr = torch.sum(torch.abs(kspace.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(kspace.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(kspace.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(kspace.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noisy_kspace 
= kspace + noise + + return noisy_kspace + + +# def mri_fft(raw_mri, _SNR): +# mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) +# spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') +# # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum +# kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + +# if _SNR > 0: +# noisy_kspace = add_gaussian_noise(kspace, _SNR) +# else: +# noisy_kspace = kspace + +# noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) +# noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + +# return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ +# kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1) + +def mri_fft_m4raw(lq_mri, hq_mri): + # breakpoint() + lq_mri = torch.tensor(lq_mri[0])[None, :, :, None].to(torch.float32) + lq_mri_spectrum = torch.fft.fftn(lq_mri, dim=(1, 2), norm='ortho') + lq_mri_spectrum = torch.fft.fftshift(lq_mri_spectrum, dim=(1, 2)) + + lq_mri = mri_inver_fourier_transform_2d(lq_mri_spectrum) + lq_mri = torch.cat([torch.real(lq_mri), torch.imag(lq_mri)], dim=-1) + lq_kspace = torch.cat([torch.real(lq_mri_spectrum), torch.imag(lq_mri_spectrum)], dim=-1) + + + hq_mri = torch.tensor(hq_mri[0])[None, :, :, None].to(torch.float32) + hq_mri_spectrum = torch.fft.fftn(hq_mri, dim=(1, 2), norm='ortho') + hq_mri_spectrum = torch.fft.fftshift(hq_mri_spectrum, dim=(1, 2)) + + hq_mri = mri_inver_fourier_transform_2d(hq_mri_spectrum) + hq_mri = torch.cat([torch.real(hq_mri), torch.imag(hq_mri)], dim=-1) + hq_kspace = torch.cat([torch.real(hq_mri_spectrum), torch.imag(hq_mri_spectrum)], dim=-1) + + # breakpoint() + return lq_kspace[0], lq_mri[0].permute(2, 0, 1), \ + hq_kspace[0], hq_mri[0].permute(2, 0, 1) + +def mri_fft(raw_mri, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + + if _SNR > 0: + noisy_kspace = add_gaussian_noise(kspace, _SNR) + else: + noisy_kspace = kspace + + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.cat([torch.real(noisy_mri), torch.imag(noisy_mri)], dim=-1) + noisy_kspace = torch.cat([torch.real(noisy_kspace), torch.imag(noisy_kspace)], dim=-1) + + raw_ksapce = torch.cat([torch.real(kspace), torch.imag(kspace)], dim=-1) + raw_mri = mri_inver_fourier_transform_2d(kspace) + raw_mri = torch.cat([torch.real(raw_mri), torch.imag(raw_mri)], dim=-1) + + # breakpoint() + return noisy_kspace[0], noisy_mri[0].permute(2, 0, 1), \ + raw_ksapce[0], raw_mri[0].permute(2, 0, 1) + +def undersample_mri(raw_mri, _MRIDOWN, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + raw_spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + raw_kspace = torch.fft.fftshift(raw_spectrum, dim=(1, 2)) + + if not _MRIDOWN == "0X": + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] # [1, 240] + + ### under-sample the kspace 
data. + kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) + + ### add low-field noise to the kspace data. + if _SNR > 0: + noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) + else: + noisy_kspace = masked_kspace + + else: + if _SNR > 0: + noisy_kspace = add_gaussian_noise(raw_kspace, _SNR) + else: + noisy_kspace = raw_kspace + + mask = torch.ones([1,240]) + # breakpoint() + ### conver the corrupted kspace data back to noisy MRI image. + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.cat([torch.real(noisy_mri), torch.imag(noisy_mri)], dim=-1) + + noisy_kspace_ = torch.cat([torch.real(noisy_kspace), torch.imag(noisy_kspace)], dim=-1) + + raw_mri = mri_inver_fourier_transform_2d(raw_kspace) + raw_mri = torch.cat([torch.real(raw_mri), torch.imag(raw_mri)], dim=-1) + raw_kspace = torch.cat([torch.real(raw_kspace), torch.imag(raw_kspace)], dim=-1) + + return noisy_kspace_[0], noisy_mri[0].permute(2, 0, 1), \ + raw_kspace[0], raw_mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + +# def undersample_mri(raw_mri, _MRIDOWN, _SNR): +# mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) +# if _MRIDOWN == "4X": +# mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 +# elif _MRIDOWN == "8X": +# mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + +# ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + +# shape = [240, 240, 1] +# mask = ff(shape, seed=1337) +# mask = mask[:, :, 0] # [1, 240] +# # print("mask:", mask.shape) +# # print("original MRI:", mri) + +# # print("original MRI:", mri.shape) +# ### under-sample the kspace data. +# kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) +# ### add low-field noise to the kspace data. +# if _SNR > 0: +# noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) +# else: +# noisy_kspace = masked_kspace + +# ### conver the corrupted kspace data back to noisy MRI image. +# noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) +# noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + +# return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ +# kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. 
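+ For example, center_fractions=[0.08, 0.04] with accelerations=[4, 8] keeps 8% of the central k-space columns at 4-fold acceleration or 4% at 8-fold, with one (center_fraction, acceleration) pair drawn uniformly at random on each call.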
+ """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. 
This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/m4_utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/m4_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0d76246b0c8fb757b6b589b7f889e0316696d0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/m4_utils.py @@ -0,0 +1,272 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. + + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. 
+ y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + +# def fft2c(data): +# """ +# Apply centered 2 dimensional Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The FFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# data = torch.fft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + +# def ifft2c(data): +# """ +# Apply centered 2-dimensional Inverse Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The IFFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# # data = torch.ifft(data, 2, normalized=True) +# data = torch.fft.ifft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. 
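+ (i.e. |z|^2 = z_re^2 + z_im^2, summed over the trailing complex dimension of size 2)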
+ + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/m4raw_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/m4raw_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..0df4823e792595b4fcf066350c62ea30c02ec443 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/m4raw_dataloader.py @@ -0,0 +1,488 @@ + +from __future__ import print_function, division +from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union +import sys +sys.path.append('.') +from glob import glob +import os +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import numpy as np +import torch +from torch.utils.data import Dataset + +import h5py +from matplotlib import pyplot as plt +from dataloaders.utils import ifft2c, fft2c, complex_abs +from dataloaders.kspace_subsample import create_mask_for_mask_type + +import argparse +from torch.utils.data import DataLoader +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. 
+ + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +# def normal(x): +# y = np.zeros_like(x) +# for i in range(y.shape[0]): +# x_min = x[i].min() +# x_max = x[i].max() +# y[i] = (x[i] - x_min)/(x_max-x_min) +# return y + + + +def undersample_mri(kspace, _MRIDOWN): + # print("kspace shape:", kspace.shape) ## [18, 4, 256, 256, 2] + + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [256, 256, 1] + mask = ff(shape, seed=1337) ## [1, 256, 1] + + mask = mask[:, :, 0] # [1, 256] + + masked_kspace = kspace * mask[None, None, :, :, None] + + return masked_kspace, mask.unsqueeze(-1) + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. 
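+ e.g. a complex array of shape (H, W) becomes a real tensor of shape (H, W, 2).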
+ """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + +def read_h5(file_name, _MRIDOWN): + crop_size=[240,240] + + hf = h5py.File(file_name) + volume_kspace = hf['kspace'][()] + + slice_kspace = volume_kspace + slice_kspace = to_tensor(slice_kspace) + import imageio as io + + + + # masked_kspace, mask = apply_mask(slice_kspace, mask_func, seed=123456) + masked_kspace, mask = undersample_mri(slice_kspace, _MRIDOWN) + lq_image = ifft2c(masked_kspace) + lq_image = complex_center_crop(lq_image, crop_size) + lq_image = complex_abs(lq_image) + lq_image = rss(lq_image, dim=1) + # breakpoint() + # io.imsave('lq_image.png', lq_image[0].numpy().astype(np.uint8)) + lq_image_list=[] + mean_list=[] + std_list=[] + for i in range(lq_image.shape[0]): + image, mean, std = normalize_instance(lq_image[i], eps=1e-11) + image = image.clamp(-6, 6) + lq_image_list.append(image) + mean_list.append(mean) + std_list.append(std) + + target = ifft2c(slice_kspace) + target = complex_center_crop(target, crop_size) + target = complex_abs(target) + target = rss(target, dim=1) + # io.imsave('target1.png', target[10].numpy().astype(np.uint8)) + # breakpoint() + target_list=[] + + for i in range(lq_image.shape[0]): + target_slice = normalize(target[i], mean_list[i], std_list[i], eps=1e-11) + target_slice = target_slice.clamp(-6, 6) + target_list.append(target_slice) + + return torch.stack(lq_image_list), torch.stack(target_list), torch.stack(mean_list), torch.stack(std_list) + + + + +class M4Raw_TrainSet(Dataset): + def __init__(self, args): + # mask_func = create_mask_for_mask_type( + # args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + # ) + + self._MRIDOWN = args.MRIDOWN + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + T2_input_list = [input_list1, input_list2, input_list3] + + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + + print('TrainSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + 
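# read_h5 returns per-slice (undersampled images, fully sampled targets, normalization means, normalization stds) for this repetition +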
self.T2_images[j][i], self.T2_masked_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, self._MRIDOWN) + + + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + print("Train data shape:", self.T1_images.shape) + # breakpoint() + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. + # choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + # breakpoint() + # import imageio as io + # io.imsave('T1_images.png', (T1_images[0]*255).astype(np.uint8)) + # io.imsave('T1_labels.png', (T1_labels[0]*255).astype(np.uint8)) + # io.imsave('T2_images.png', (T2_images[0]*255).astype(np.uint8)) + # io.imsave('T2_labels.png', (T2_labels[0]*255).astype(np.uint8)) + # breakpoint() + + t1_in=T1_images + t1=T1_labels + t2_in=T2_images + t2=T2_labels + + sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + + # breakpoint() + sample = {'image_in': t1_in, + 'image': t1, + + 'target_in': t2_in, + 'target': t2} + + return sample, sample_stats + + + +class M4Raw_TestSet(Dataset): + def __init__(self, args): + # mask_func = create_mask_for_mask_type( + # args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + # ) + self._MRIDOWN = args.MRIDOWN + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val' + '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + T2_input_list = [input_list1,input_list2,input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + print('TrainSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, self._MRIDOWN) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i], 
self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, self._MRIDOWN) + + + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + print("Train data shape:", self.T1_images.shape) + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + # choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. + choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + # breakpoint() + # import imageio as io + # io.imsave('T1_images.png', (T1_images[0]*255).astype(np.uint8)) + # io.imsave('T1_labels.png', (T1_labels[0]*255).astype(np.uint8)) + # io.imsave('T2_images.png', (T2_images[0]*255).astype(np.uint8)) + # io.imsave('T2_labels.png', (T2_labels[0]*255).astype(np.uint8)) + # breakpoint() + + t1_in=T1_images + t1=T1_labels + t2_in=T2_images + t2=T2_labels + + sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + + # breakpoint() + sample = {'image_in': t1_in, + 'image': t1, + + 'target_in': t2_in, + 'target': t2} + + return sample, sample_stats + + + +def compute_metrics(image, labels): + MSE = mean_squared_error(labels, image)/np.var(labels) + PSNR = peak_signal_noise_ratio(labels, image) + SSIM = structural_similarity(labels, image) + + # print("metrics:", MSE, PSNR, SSIM) + + return MSE, PSNR, SSIM + +def complex_abs_eval(data): + return (data[0:1, :, :] ** 2 + data[1:2, :, :] ** 2).sqrt() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--root_path', type=str, default='/data/qic99/MRI_recon/') + parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') + parser.add_argument('--kspace_refine', type=str, default='False', help='whether use the image reconstructed from kspace network.') + args = parser.parse_args() + + db_test = M4Raw_TestSet(args) + testloader = DataLoader(db_test, batch_size=4, shuffle=False, num_workers=4, pin_memory=True) + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + save_dir = "./visualize_images/" + + for i_batch, sampled_batch in enumerate(testloader): + # t1_in, t2_in = sampled_batch['t1_in'].cuda(), sampled_batch['t2_in'].cuda() + # t1, t2 = sampled_batch['t1_labels'].cuda(), sampled_batch['t2_labels'].cuda() + + t1_in, t2_in = sampled_batch['ref_image_sub'].cuda(), sampled_batch['tag_image_sub'].cuda() + t1, t2 = sampled_batch['ref_image_full'].cuda(), sampled_batch['tag_image_full'].cuda() + + # breakpoint() + for j in range(t1_in.shape[0]): + # t1_in_img = (np.clip(complex_abs_eval(t1_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(complex_abs_eval(t1[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(complex_abs_eval(t2_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) 
+ # t2_img = (np.clip(complex_abs_eval(t2[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # breakpoint() + t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + # t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + + # print(t1_in_img.shape, t1_img.shape) + t1_MSE, t1_PSNR, t1_SSIM = compute_metrics(t1_in_img, t1_img) + t2_MSE, t2_PSNR, t2_SSIM = compute_metrics(t2_in_img, t2_img) + + + t1_MSE_all.append(t1_MSE) + t1_PSNR_all.append(t1_PSNR) + t1_SSIM_all.append(t1_SSIM) + + t2_MSE_all.append(t2_MSE) + t2_PSNR_all.append(t2_PSNR) + t2_SSIM_all.append(t2_SSIM) + + + # print("t1_PSNR:", t1_PSNR_all) + print("t1_PSNR:", round(np.array(t1_PSNR_all).mean(), 4)) + print("t1_NMSE:", round(np.array(t1_MSE_all).mean(), 4)) + print("t1_SSIM:", round(np.array(t1_SSIM_all).mean(), 4)) + + print("t2_PSNR:", round(np.array(t2_PSNR_all).mean(), 4)) + print("t2_NMSE:", round(np.array(t2_MSE_all).mean(), 4)) + print("t2_SSIM:", round(np.array(t2_SSIM_all).mean(), 4)) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/m4raw_std_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/m4raw_std_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ee7835f421a22ed9a8514884bc95e1498dc378 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/m4raw_std_dataloader.py @@ -0,0 +1,487 @@ + +from __future__ import print_function, division +from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union +import sys +sys.path.append('.') +from glob import glob +import os +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import numpy as np +import torch +from torch.utils.data import Dataset + +import h5py +from matplotlib import pyplot as plt +from dataloaders.m4_utils import ifft2c, fft2c, complex_abs +from dataloaders.kspace_subsample import create_mask_for_mask_type + +import argparse +from torch.utils.data import DataLoader +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. 
+ + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +# def normal(x): +# y = np.zeros_like(x) +# for i in range(y.shape[0]): +# x_min = x[i].min() +# x_max = x[i].max() +# y[i] = (x[i] - x_min)/(x_max-x_min) +# return y + + + +# def undersample_mri(kspace, _MRIDOWN): +# # print("kspace shape:", kspace.shape) ## [18, 4, 256, 256, 2] + +# if _MRIDOWN == "4X": +# mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 +# elif _MRIDOWN == "8X": +# mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + +# ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + +# shape = [256, 256, 1] +# mask = ff(shape, seed=1337) ## [1, 256, 1] + +# mask = mask[:, :, 0] # [1, 256] + +# masked_kspace = kspace * mask[None, None, :, :, None] + +# return masked_kspace, mask.unsqueeze(-1) + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. 
+ + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + +def read_h5(file_name, mask_func): + crop_size=[240,240] + + hf = h5py.File(file_name) + volume_kspace = hf['kspace'][()] + + slice_kspace = volume_kspace + slice_kspace = to_tensor(slice_kspace) + import imageio as io + + + + masked_kspace, mask = apply_mask(slice_kspace, mask_func, seed=123456) + lq_image = ifft2c(masked_kspace) + lq_image = complex_center_crop(lq_image, crop_size) + lq_image = complex_abs(lq_image) + lq_image = rss(lq_image, dim=1) + # breakpoint() + # io.imsave('lq_image.png', lq_image[0].numpy().astype(np.uint8)) + lq_image_list=[] + mean_list=[] + std_list=[] + for i in range(lq_image.shape[0]): + image, mean, std = normalize_instance(lq_image[i], eps=1e-11) + image = image.clamp(-6, 6) + lq_image_list.append(image) + mean_list.append(mean) + std_list.append(std) + + target = ifft2c(slice_kspace) + target = complex_center_crop(target, crop_size) + target = complex_abs(target) + target = rss(target, dim=1) + # io.imsave('target1.png', target[10].numpy().astype(np.uint8)) + # breakpoint() + target_list=[] + + for i in range(lq_image.shape[0]): + target_slice = normalize(target[i], mean_list[i], std_list[i], eps=1e-11) + target_slice = target_slice.clamp(-6, 6) + target_list.append(target_slice) + + return torch.stack(lq_image_list), torch.stack(target_list), torch.stack(mean_list), torch.stack(std_list) + + + + +class M4Raw_TrainSet(Dataset): + def __init__(self, args): + mask_func = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + + self._MRIDOWN = args.MRIDOWN + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_train', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + T2_input_list = [input_list1, input_list2, input_list3] + + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + + print('TrainSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, mask_func) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, mask_func) + + + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + 
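# same flattening for T2: [files, repetitions, slices, H, W] -> [files*slices, repetitions, H, W] +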
self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + print("Train data shape:", self.T1_images.shape) + # breakpoint() + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. + # choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + # breakpoint() + # import imageio as io + # io.imsave('T1_images.png', (T1_images[0]*255).astype(np.uint8)) + # io.imsave('T1_labels.png', (T1_labels[0]*255).astype(np.uint8)) + # io.imsave('T2_images.png', (T2_images[0]*255).astype(np.uint8)) + # io.imsave('T2_labels.png', (T2_labels[0]*255).astype(np.uint8)) + # breakpoint() + + t1_in=T1_images + t1=T1_labels + t2_in=T2_images + t2=T2_labels + + sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + + # breakpoint() + sample = {'image_in': t1_in, + 'image': t1, + + 'target_in': t2_in, + 'target': t2} + + return sample, sample_stats + + + +class M4Raw_TestSet(Dataset): + def __init__(self, args): + mask_func = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + + self.input_normalize = args.input_normalize + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val' + '*_T102.h5'))) + input_list2 = [path.replace('_T102.h5','_T101.h5') for path in input_list1] + input_list3 = [path.replace('_T102.h5','_T103.h5') for path in input_list1] + T1_input_list = [input_list1, input_list2, input_list3] + + input_list1 = sorted(glob(os.path.join(args.root_path, 'multicoil_val', '*_T202.h5'))) + input_list2 = [path.replace('_T202.h5','_T201.h5') for path in input_list1] + input_list3 = [path.replace('_T202.h5','_T203.h5') for path in input_list1] + T2_input_list = [input_list1,input_list2,input_list3] + + self.T1_input_list = T1_input_list + self.T2_input_list = T2_input_list + self.T1_images = np.zeros([len(input_list1),len(T1_input_list), 18, 240, 240]) + self.T2_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_masked_images = np.zeros([len(input_list2),len(T2_input_list), 18, 240, 240]) + self.T2_mean = np.zeros([len(input_list2),len(T2_input_list), 18]) + self.T2_std = np.zeros([len(input_list2),len(T2_input_list), 18]) + print('TrainSet loading...') + for i in range(len(self.T1_input_list)): + for j, path in enumerate(T1_input_list[i]): + _, self.T1_images[j][i], _, _ = read_h5(path, mask_func) + + self.T1_labels = np.mean(self.T1_images, axis=1) + + for i in range(len(self.T2_input_list)): + for j, path in enumerate(T2_input_list[i]): + self.T2_images[j][i], self.T2_masked_images[j][i], self.T2_mean[j][i], self.T2_std[j][i] = read_h5(path, mask_func) + + + + self.T2_labels = np.mean(self.T2_images, axis=1) + self.T2_images = self.T2_masked_images + + print('Finish loading') + + self.T1_images = self.T1_images.transpose(0,2,1,3,4).reshape(-1,len(T1_input_list),240,240) + self.T2_images = self.T2_images.transpose(0,2,1,3,4).reshape(-1,len(T2_input_list),240,240) + 
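# the repetition-averaged labels are flattened to [files*slices, 1, H, W] below +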
self.T1_labels = self.T1_labels.reshape(-1,1,240,240) + self.T2_labels = self.T2_labels.reshape(-1,1,240,240) + self.T2_mean = self.T2_mean.reshape(-1,3) + self.T2_std = self.T2_std.reshape(-1,3) + print("Train data shape:", self.T1_images.shape) + + + def __len__(self): + return len(self.T1_images) + + def __getitem__(self, idx): + + T1_images = self.T1_images[idx] + T2_images = self.T2_images[idx] + T1_labels = self.T1_labels[idx] + T2_labels = self.T2_labels[idx] + # choices = np.random.choice([i for i in range(len(self.T1_input_list))],1) ## 每次都是从三个repetition中选择一个作为input. + choices = np.random.choice([0],1) ## 用第一个repetition作为输入图像进行测试 + T1_images = T1_images[choices] + T2_images = T2_images[choices] + t2_mean = self.T2_mean[idx][choices] + t2_std = self.T2_std[idx][choices] + # breakpoint() + # import imageio as io + # io.imsave('T1_images.png', (T1_images[0]*255).astype(np.uint8)) + # io.imsave('T1_labels.png', (T1_labels[0]*255).astype(np.uint8)) + # io.imsave('T2_images.png', (T2_images[0]*255).astype(np.uint8)) + # io.imsave('T2_labels.png', (T2_labels[0]*255).astype(np.uint8)) + # breakpoint() + + t1_in=T1_images + t1=T1_labels + t2_in=T2_images + t2=T2_labels + + sample_stats = {"t2_mean": t2_mean, "t2_std": t2_std} + + + # breakpoint() + sample = {'image_in': t1_in, + 'image': t1, + + 'target_in': t2_in, + 'target': t2} + + return sample, sample_stats + + + +def compute_metrics(image, labels): + MSE = mean_squared_error(labels, image)/np.var(labels) + PSNR = peak_signal_noise_ratio(labels, image) + SSIM = structural_similarity(labels, image) + + # print("metrics:", MSE, PSNR, SSIM) + + return MSE, PSNR, SSIM + +def complex_abs_eval(data): + return (data[0:1, :, :] ** 2 + data[1:2, :, :] ** 2).sqrt() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--root_path', type=str, default='/data/qic99/MRI_recon/') + parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') + parser.add_argument('--kspace_refine', type=str, default='False', help='whether use the image reconstructed from kspace network.') + args = parser.parse_args() + + db_test = M4Raw_TestSet(args) + testloader = DataLoader(db_test, batch_size=4, shuffle=False, num_workers=4, pin_memory=True) + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + save_dir = "./visualize_images/" + + for i_batch, sampled_batch in enumerate(testloader): + # t1_in, t2_in = sampled_batch['t1_in'].cuda(), sampled_batch['t2_in'].cuda() + # t1, t2 = sampled_batch['t1_labels'].cuda(), sampled_batch['t2_labels'].cuda() + + t1_in, t2_in = sampled_batch['ref_image_sub'].cuda(), sampled_batch['tag_image_sub'].cuda() + t1, t2 = sampled_batch['ref_image_full'].cuda(), sampled_batch['tag_image_full'].cuda() + + # breakpoint() + for j in range(t1_in.shape[0]): + # t1_in_img = (np.clip(complex_abs_eval(t1_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(complex_abs_eval(t1[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(complex_abs_eval(t2_in[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(complex_abs_eval(t2[j])[0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # breakpoint() + t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + t2_img = 
(np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + # t1_in_img = (np.clip(t1_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t1_img = (np.clip(t1[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_in_img = (np.clip(t2_in[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + # t2_img = (np.clip(t2[j][0].cpu().numpy(), 0, 1) * 255).astype(np.uint8) + + + # print(t1_in_img.shape, t1_img.shape) + t1_MSE, t1_PSNR, t1_SSIM = compute_metrics(t1_in_img, t1_img) + t2_MSE, t2_PSNR, t2_SSIM = compute_metrics(t2_in_img, t2_img) + + + t1_MSE_all.append(t1_MSE) + t1_PSNR_all.append(t1_PSNR) + t1_SSIM_all.append(t1_SSIM) + + t2_MSE_all.append(t2_MSE) + t2_PSNR_all.append(t2_PSNR) + t2_SSIM_all.append(t2_SSIM) + + + # print("t1_PSNR:", t1_PSNR_all) + print("t1_PSNR:", round(np.array(t1_PSNR_all).mean(), 4)) + print("t1_NMSE:", round(np.array(t1_MSE_all).mean(), 4)) + print("t1_SSIM:", round(np.array(t1_SSIM_all).mean(), 4)) + + print("t2_PSNR:", round(np.array(t2_PSNR_all).mean(), 4)) + print("t2_NMSE:", round(np.array(t2_MSE_all).mean(), 4)) + print("t2_SSIM:", round(np.array(t2_SSIM_all).mean(), 4)) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/abd_dataset_utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/abd_dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..25827ddeb9bb48fa5680b87d111b841ad2ebb892 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/abd_dataset_utils.py @@ -0,0 +1,65 @@ +""" +Utils for datasets +""" +import numpy as np +import os +import sys +import numpy as np +import pdb +import SimpleITK as sitk +from .niftiio import read_nii_bysitk + + +def get_normalize_op(modality, fids): + """ + As title + Args: + modality: CT or MR + fids: fids for the fold + """ + + def get_CT_statistics(scan_fids): + """ + As CT are quantitative, get mean and std for CT images for image normalizing + As in reality we might not be able to load all images at a time, we would better detach statistics calculation with actual data loading + However, in unseen dataset we have no clues about the data statistics at all so just normalize each 3D image to zero mean unit variance + """ + total_val = 0 + n_pix = 0 + for fid in scan_fids: + in_img = read_nii_bysitk(fid) + total_val += in_img.sum() + n_pix += np.prod(in_img.shape) + del in_img + meanval = total_val / n_pix + + total_var = 0 + for fid in scan_fids: + in_img = read_nii_bysitk(fid) + total_var += np.sum((in_img - meanval) ** 2 ) + del in_img + var_all = total_var / n_pix + + global_std = var_all ** 0.5 + + return meanval, global_std + + + if modality == 'SABSCT': + ct_mean, ct_std = get_CT_statistics(fids) + + def CT_normalize(x_in): + """ + Normalizing CT images, based on global statistics + """ + return x_in, ct_mean, ct_std + + return CT_normalize #, {'mean': ct_mean, 'std': ct_std} + + else: # modality == 'CHAOST2' : + + def MR_normalize(x_in): + return x_in, x_in.mean(), x_in.std() + + return MR_normalize #, {'mean': None, 'std': None} # we do not really need the global statistics for MR + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/image_transforms.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/image_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..277bdb02878221816c6a69720a7c98c41bbd2dcb --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/image_transforms.py @@ -0,0 +1,319 @@ +""" +Image transforms functions for data 
augmentation +Credit to Dr. Jo Schlemper +""" +try: + from collections import Sequence +except: + from collections.abc import Sequence +import cv2 +import numpy as np +import scipy +from scipy.ndimage.filters import gaussian_filter +from scipy.ndimage.interpolation import map_coordinates +from numpy.lib.stride_tricks import as_strided + +###### UTILITIES ###### +def random_num_generator(config, random_state=np.random): + if config[0] == 'uniform': + ret = random_state.uniform(config[1], config[2], 1)[0] + elif config[0] == 'lognormal': + ret = random_state.lognormal(config[1], config[2], 1)[0] + else: + #print(config) + raise Exception('unsupported format') + return ret + +def get_translation_matrix(translation): + """ translation: [tx, ty] """ + tx, ty = translation + translation_matrix = np.array([[1, 0, tx], + [0, 1, ty], + [0, 0, 1]]) + return translation_matrix + + + +def get_rotation_matrix(rotation, input_shape, centred=True): + theta = np.pi / 180 * np.array(rotation) + if centred: + rotation_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), rotation, 1) + rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]]) + else: + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1]]) + return rotation_matrix + +def get_zoom_matrix(zoom, input_shape, centred=True): + zx, zy = zoom + if centred: + zoom_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), 0, zoom[0]) + zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]]) + else: + zoom_matrix = np.array([[zx, 0, 0], + [0, zy, 0], + [0, 0, 1]]) + return zoom_matrix + +def get_shear_matrix(shear_angle): + theta = (np.pi * shear_angle) / 180 + shear_matrix = np.array([[1, -np.sin(theta), 0], + [0, np.cos(theta), 0], + [0, 0, 1]]) + return shear_matrix + +###### AFFINE TRANSFORM ###### +class RandomAffine(object): + """Apply random affine transformation on a numpy.ndarray (H x W x C) + Comment by co1818: this is still doing affine on 2d (H x W plane). + A same transform is applied to all C channels + + Parameter: + ---------- + + alpha: Range [0, 4] seems good for small images + + order: interpolation method (c.f. opencv) + """ + + def __init__(self, + rotation_range=None, + translation_range=None, + shear_range=None, + zoom_range=None, + zoom_keep_aspect=False, + interp='bilinear', + order=3): + """ + Perform an affine transforms. + + Arguments + --------- + rotation_range : one integer or float + image will be rotated randomly between (-degrees, degrees) + + translation_range : (x_shift, y_shift) + shifts in pixels + + *NOT TESTED* shear_range : float + image will be sheared randomly between (-degrees, degrees) + + zoom_range : (zoom_min, zoom_max) + list/tuple with two floats between [0, infinity). + first float should be less than the second + lower and upper bounds on percent zoom. + Anything less than 1.0 will zoom in on the image, + anything greater than 1.0 will zoom out on the image. + e.g. 
(0.7, 1.0) will only zoom in, + (1.0, 1.4) will only zoom out, + (0.7, 1.4) will randomly zoom in or out + """ + + self.rotation_range = rotation_range + self.translation_range = translation_range + self.shear_range = shear_range + self.zoom_range = zoom_range + self.zoom_keep_aspect = zoom_keep_aspect + self.interp = interp + self.order = order + + def build_M(self, input_shape): + tfx = [] + final_tfx = np.eye(3) + if self.rotation_range: + rot = np.random.uniform(-self.rotation_range, self.rotation_range) + tfx.append(get_rotation_matrix(rot, input_shape)) + if self.translation_range: + tx = np.random.uniform(-self.translation_range[0], self.translation_range[0]) + ty = np.random.uniform(-self.translation_range[1], self.translation_range[1]) + tfx.append(get_translation_matrix((tx,ty))) + if self.shear_range: + rot = np.random.uniform(-self.shear_range, self.shear_range) + tfx.append(get_shear_matrix(rot)) + if self.zoom_range: + sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) + if self.zoom_keep_aspect: + sy = sx + else: + sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1]) + + tfx.append(get_zoom_matrix((sx, sy), input_shape)) + + for tfx_mat in tfx: + final_tfx = np.dot(tfx_mat, final_tfx) + + return final_tfx.astype(np.float32) + + def __call__(self, image): + # build matrix + input_shape = image.shape[:2] + M = self.build_M(input_shape) + + res = np.zeros_like(image) + #if isinstance(self.interp, Sequence): + if type(self.order) is list or type(self.order) is tuple: + for i, intp in enumerate(self.order): + res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp) + else: + # squeeze if needed + orig_shape = image.shape + image_s = np.squeeze(image) + res = affine_transform_via_M(image_s, M[:2], interp=self.order) + res = res.reshape(orig_shape) + + #res = affine_transform_via_M(image, M[:2], interp=self.order) + + return res + +def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST): + imshape = image.shape + shape_size = imshape[:2] + + # Random affine + warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1], + flags=interp, borderMode=borderMode) + + #print(imshape, warped.shape) + + warped = warped[..., np.newaxis].reshape(imshape) + + return warped + +###### ELASTIC TRANSFORM ###### +def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random): + """Elastic deformation of image as described in [Simard2003]_. + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + """ + assert image.ndim == 3 + shape = image.shape[:2] + + dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), + sigma, mode="constant", cval=0) * alpha + dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), + sigma, mode="constant", cval=0) * alpha + + x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') + indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))] + result = np.empty_like(image) + for i in range(image.shape[2]): + result[:, :, i] = map_coordinates( + image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape) + return result + + +def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False): + """Expects data to be (nx, ny, n1 ,..., nm) + params: + ------ + + alpha: + the scaling parameter. 
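A minimal usage sketch for the RandomAffine / affine_transform_via_M helpers defined above; the slice shape and parameter values are illustrative assumptions, not values taken from any config in this diff.

import numpy as np

aug = RandomAffine(rotation_range=20,
                   translation_range=(15, 15),
                   shear_range=20,
                   zoom_range=(0.8, 1.2),
                   zoom_keep_aspect=True,
                   order=3)
slice_hwc = np.random.rand(192, 192, 2).astype(np.float32)  # H x W x C toy slice
warped = aug(slice_hwc)   # one sampled affine, applied identically to both channels
assert warped.shape == slice_hwc.shape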
+ E.g.: alpha=2 => distorts images up to 2x scaling + + sigma: + standard deviation of gaussian filter. + E.g. + low (sig~=1e-3) => no smoothing, pixelated. + high (1/5 * imsize) => smooth, more like affine. + very high (1/2*im_size) => translation + """ + + if random_state is None: + random_state = np.random.RandomState(None) + + shape = image.shape + imsize = shape[:2] + dim = shape[2:] + + # Random affine + blur_size = int(4*sigma) | 1 + dx = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, + ksize=(blur_size, blur_size), sigmaX=sigma) * alpha + dy = cv2.GaussianBlur(random_state.rand(*imsize)*2-1, + ksize=(blur_size, blur_size), sigmaX=sigma) * alpha + + # use as_strided to copy things over across n1...nn channels + dx = as_strided(dx.astype(np.float32), + strides=(0,) * len(dim) + (4*shape[1], 4), + shape=dim+(shape[0], shape[1])) + dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim)))) + + dy = as_strided(dy.astype(np.float32), + strides=(0,) * len(dim) + (4*shape[1], 4), + shape=dim+(shape[0], shape[1])) + dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim)))) + + coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim]) + indices = [np.reshape(e+de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:], + [dy, dx] + [0] * len(dim))] + + if lazy: + return indices + + return map_coordinates(image, indices, order=order, mode='reflect').reshape(shape) + +class ElasticTransform(object): + """Apply elastic transformation on a numpy.ndarray (H x W x C) + """ + + def __init__(self, alpha, sigma, order=1): + self.alpha = alpha + self.sigma = sigma + self.order = order + + def __call__(self, image): + if isinstance(self.alpha, Sequence): + alpha = random_num_generator(self.alpha) + else: + alpha = self.alpha + if isinstance(self.sigma, Sequence): + sigma = random_num_generator(self.sigma) + else: + sigma = self.sigma + return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order) + +class RandomFlip3D(object): + + def __init__(self, h=True, v=True, t=True, p=0.5): + """ + Randomly flip an image horizontally and/or vertically with + some probability. + + Arguments + --------- + h : boolean + whether to horizontally flip w/ probability p + + v : boolean + whether to vertically flip w/ probability p + + p : float between [0,1] + probability with which to apply allowed flipping operations + """ + self.horizontal = h + self.vertical = v + self.depth = t + self.p = p + + def __call__(self, x, y=None): + # horizontal flip with p = self.p + if self.horizontal: + if np.random.random() < self.p: + x = x[::-1, ...] + + # vertical flip with p = self.p + if self.vertical: + if np.random.random() < self.p: + x = x[:, ::-1, ...] + + if self.depth: + if np.random.random() < self.p: + x = x[..., ::-1] + + return x + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/math.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/math.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0d76246b0c8fb757b6b589b7f889e0316696d0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/math.py @@ -0,0 +1,272 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. 
+ + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + +# def fft2c(data): +# """ +# Apply centered 2 dimensional Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The FFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# data = torch.fft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + +# def ifft2c(data): +# """ +# Apply centered 2-dimensional Inverse Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The IFFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# # data = torch.ifft(data, 2, normalized=True) +# data = torch.fft.ifft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. 
+ """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. 
+ + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/niftiio.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/niftiio.py new file mode 100644 index 0000000000000000000000000000000000000000..19fce7bc59793d6c2711b497ee01577433788172 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/niftiio.py @@ -0,0 +1,47 @@ +""" +Utils for datasets +""" +import numpy as np +import numpy as np +import SimpleITK as sitk + + +def read_nii_bysitk(input_fid, peel_info = False): + """ read nii to numpy through simpleitk + peelinfo: taking direction, origin, spacing and metadata out + """ + img_obj = sitk.ReadImage(input_fid) + img_np = sitk.GetArrayFromImage(img_obj) + if peel_info: + info_obj = { + "spacing": img_obj.GetSpacing(), + "origin": img_obj.GetOrigin(), + "direction": img_obj.GetDirection(), + "array_size": img_np.shape + } + return img_np, info_obj + else: + return img_np + +def convert_to_sitk(input_mat, peeled_info): + """ + write a numpy array to sitk image object with essential meta-data + """ + nii_obj = sitk.GetImageFromArray(input_mat) + if peeled_info: + nii_obj.SetSpacing( peeled_info["spacing"] ) + nii_obj.SetOrigin( peeled_info["origin"] ) + nii_obj.SetDirection(peeled_info["direction"] ) + return nii_obj + +def np2itk(img, ref_obj): + """ + img: numpy array + ref_obj: reference sitk object for copying information from + """ + itk_obj = sitk.GetImageFromArray(img) + itk_obj.SetSpacing( ref_obj.GetSpacing() ) + itk_obj.SetOrigin( ref_obj.GetOrigin() ) + itk_obj.SetDirection( ref_obj.GetDirection() ) + return itk_obj + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/subsample.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..d0620da3414c6077e4293376fb8a9be01ad19990 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/subsample.py @@ -0,0 +1,195 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. 
+ """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. 
This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/transform_albu.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/transform_albu.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac07fc3237abbf310cd0b088f5b49e1cd042735 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/transform_albu.py @@ -0,0 +1,125 @@ +# -*- encoding: utf-8 -*- +#Time :2022/02/24 18:14:15 +#Author :Hao Chen +#FileName :trans_lib.py +#Version :2.0 + +import cv2 +import torch +import numpy as np +import albumentations as A +def gaussian_noise(img, mean, sigma): + return img + torch.FloatTensor(img.shape).normal_(mean=mean, std=sigma) +# from albumentations.pytorch import ShiftScaleRotate + +def GammaInterference(img): + # Shape Span + gamma = np.random.random() * 1.5 + 0.25 # 0.25 ~ 1.75 + # gamma = np.random.random() * 1.75 + 0.25 # 0.25 ~ 1.75 + img = gamma_concern(img, gamma) # concerntrate + + # Shape Tilt + choose = np.random.randint(0, 2) + direction = np.random.randint(0, 2) + + if choose == 0: + gamma = 0.2 + np.random.random() * 2.3 # 2.5 + img = gamma_power(img, gamma, direction) + else: + gamma = np.random.random() * 2.3 + 0.6 # 1.5 center + img = gamma_exp(img, gamma, direction) + + return img + + + +def get_resize_transforms(img_size = (192, 192)): + # if type == 'train': + return A.Compose([ + A.Resize(img_size[0], img_size[1]) + ], p=1.0, additional_targets={'image2': 'image', "mask2": "mask"}) + + 
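A minimal usage sketch for the paired resize transform returned by get_resize_transforms above; variable names and shapes are hypothetical.

import numpy as np

resize = get_resize_transforms(img_size=(192, 192))
t1 = np.random.rand(240, 240).astype(np.float32)   # hypothetical modality 1
t2 = np.random.rand(240, 240).astype(np.float32)   # hypothetical modality 2
seg = np.zeros((240, 240), dtype=np.uint8)
out = resize(image=t1, image2=t2, mask=seg, mask2=seg)
# out['image'], out['image2'], out['mask'], out['mask2'] are all 192 x 192 and
# share the same resampling, so the two modalities and their masks stay aligned.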
+def get_albu_transforms(type="train", img_size = (192, 192)): + if type == 'train': + compose = [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.ShiftScaleRotate(shift_limit=0.2, scale_limit=(-0.2, 0.2), + rotate_limit=5, p=0.5), + + # A.Defocus(radius=(4, 8), alias_blur=(0.2, 0.4), p=0.5), + # A.GaussNoise(var_limit=(10.0, 25.0), p=0.5), + + # A.GaussianBlur(blur_limit=(3, 7), p=0.5), + # A.Emboss(alpha=(0.5, 1.0), strength=(0.5, 1.0), p=0.5), # Added + + # A.FDA([target_image], p=1, read_fn=lambda x: x) + # A.PixelDistributionAdaptation( reference_images=[reference_image], + + # A.Defocus(radius=(4, 8), alias_blur=(0.2, 0.4), p=0.5) + + # Randomly posterize between 2 and 5 bits + # A.Posterize(num_bits=(4, 6), p=0.5), + + # A.OneOf([ + # A.RandomShadow(p=1.0), + # A.Solarize(p=1.0), + # A.RandomSunFlare(p=1.0), + # ], p=0.5), + + # A.Saturation + # A.HueSaturationValue(hue_shift_limit=0, sat_shift_limit=0, val_shift_limit=5, p=0.5), + # A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), + # contrast_limit=(-0.1, 0.1), p=0.5), + # A.MaskDropout(p=0.5), + + A.OneOf([ + A.GridDistortion(num_steps=1, distort_limit=0.3, p=1.0), + A.ElasticTransform(alpha=3, sigma=15, alpha_affine=10, p=1.0) + ], p=0.5), + + A.Resize(img_size[0], img_size[1])] + else: + compose = [A.Resize(img_size[0], img_size[1])] + + return A.Compose(compose, p=1.0, additional_targets={'image2': 'image', "mask2": "mask"}) + + + + +# Beta function +def gamma_concern(img, gamma): + mean = torch.mean(img) + + img = (img - mean) * gamma + img = img + mean + img = torch.clip(img, 0, 1) + + return img + +def gamma_power(img, gamma, direction=0): + if direction == 1: + img = 1 - img + img = torch.pow(img, gamma) + + img = img / torch.max(img) + if direction == 1: + img = 1 - img + + return img + +def gamma_exp(img, gamma, direction=0): + if direction == 1: + img = 1 - img + + img = torch.exp(img * gamma) + img = img / torch.max(img) + + if direction == 1: + img = 1 - img + return img + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/transform_utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/transform_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96f438903c6d476a8150ba1f8d1fe192e00cf5a5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/transform_utils.py @@ -0,0 +1,245 @@ +""" +Utilities for image transforms, part of the code base credits to Dr. Jo Schlemper +""" +from os.path import join +import torch +import numpy as np +import torchvision.transforms as deftfx +from . 
import image_transforms as myit +import copy +import math +from torchvision.transforms.functional import rotate as torchrotate +from torchvision.transforms.functional import InterpolationMode + +my_augv = { +'flip' : { 'v':False, 'h':False, 't': False, 'p':0.25 }, +'affine' : { + 'rotate':20, + 'shift':(15,15), + 'shear': 20, + 'scale':(0.5, 1.5), +}, +'elastic' : {'alpha':0,'sigma':0}, # medium {'alpha':20,'sigma':5}, +'reduce_2d': True, +'gamma_range': (1.0, 1.0 ), #(0.2, 1.8), +'noise' : { + 'noise_std': 0, # 0.15 + 'clip_pm1': False + }, +'bright_contrast': { + 'contrast': (1.0, 1.0), #(0.60, 1.5), + 'bright': (0, 0)#(-10, 10) + } +} + +tr_aug = { + 'aug': my_augv +} + + +def get_contrast_example(image, random_angle=0, flip=0): + if flip == [3]: + flip = [1, 2] + + # [..., H, W] + image_rotate = torchrotate(image, random_angle, + interpolation=InterpolationMode.BILINEAR) # Bilinear + image_rotate = torch.flip(image_rotate, flip) + + return image_rotate + + + +def get_geometric_transformer(aug, order=3): + affine = aug['aug'].get('affine', 0) + alpha = aug['aug'].get('elastic',{'alpha': 0})['alpha'] + sigma = aug['aug'].get('elastic',{'sigma': 0})['sigma'] + flip = aug['aug'].get('flip', {'v': True, 'h': True, 't': True, 'p':0.125}) + + tfx = [] + if 'flip' in aug['aug']: + tfx.append(myit.RandomFlip3D(**flip)) + + if 'affine' in aug['aug']: + tfx.append(myit.RandomAffine(affine.get('rotate'), + affine.get('shift'), + affine.get('shear'), + affine.get('scale'), + affine.get('scale_iso',True), + order=order)) + + if 'elastic' in aug['aug']: + tfx.append(myit.ElasticTransform(alpha, sigma)) + + input_transform = deftfx.Compose(tfx) + return input_transform + +def get_intensity_transformer(aug): + + def gamma_tansform(img): + gamma_range = aug['aug']['gamma_range'] + if isinstance(gamma_range, tuple): + gamma = np.random.rand() * (gamma_range[1] - gamma_range[0]) + gamma_range[0] + cmin = img.min() + irange = (img.max() - cmin + 1e-5) + + img = img - cmin + 1e-5 + img = irange * np.power(img * 1.0 / irange, gamma) + img = img + cmin + + elif gamma_range == False: + pass + else: + raise ValueError("Cannot identify gamma transform range {}".format(gamma_range)) + return img + + def brightness_contrast(img): + ''' + Chaitanya,K. et al. Semi-Supervised and Task-Driven data augmentation,864in: International Conference on Information Processing in Medical Imaging,865Springer. pp. 29–41. 
+ ''' + cmin, cmax = aug['aug']['bright_contrast']['contrast'] + bmin, bmax = aug['aug']['bright_contrast']['bright'] + c = np.random.rand() * (cmax - cmin) + cmin + b = np.random.rand() * (bmax - bmin) + bmin + img_mean = img.mean() + img = (img - img_mean) * c + img_mean + b + return img + + def zm_gaussian_noise(img): + """ + zero-mean gaussian noise + """ + noise_sigma = aug['aug']['noise']['noise_std'] + noise_vol = np.random.randn(*img.shape) * noise_sigma + img = img + noise_vol + + if aug['aug']['noise']['clip_pm1']: # if clip to plus-minus 1 + img = np.clip(img, -1.0, 1.0) + return img + + def compile_transform(img): + # bright contrast + if 'bright_contrast' in aug['aug'].keys(): + img = brightness_contrast(img) + + # gamma + if 'gamma_range' in aug['aug'].keys(): + img = gamma_tansform(img) + + # additive noise + if 'noise' in aug['aug'].keys(): + img = zm_gaussian_noise(img) + + return img + + return compile_transform + + +def transform_with_label(aug, add_pseudolabel = False): + """ + Doing image geometric transform + Proposed image to have the following configurations + [H x W x C + CL] + Where CL is the number of channels for the label. It is NOT a one-hot thing + """ + + geometric_tfx = get_geometric_transformer(aug) + intensity_tfx = get_intensity_transformer(aug) + + def transform(comp, c_label, c_img, c_sam, nclass, is_train, use_onehot = False): + """ + Args + comp: a numpy array with shape [H x W x C + c_label] + c_label: number of channels for a compact label. Note that the current version only supports 1 slice (H x W x 1) + nc_onehot: -1 for not using one-hot representation of mask. otherwise, specify number of classes in the label + is_train: whether this is the training set or not. If not, do not perform the geometric transform + """ + comp = copy.deepcopy(comp) + if (use_onehot is True) and (c_label != 1): + raise NotImplementedError("Only allow compact label, also the label can only be 2d") + assert c_img + c_sam + c_label == comp.shape[-1], "only allow single slice 2D label" + + if is_train is True: + _label = comp[..., c_img ] + _sam = np.expand_dims(comp[..., c_img+c_label], axis=-1) + # compact to onehot + _h_label = np.float32(np.arange( nclass ) == (_label[..., None]) ) + # print("h_label=", _h_label.shape) + # print("_sam=", _sam.shape) + + comp = np.concatenate( [comp[..., :c_img ], _h_label, _sam], -1 ) + comp = geometric_tfx(comp) + # round one_hot labels to 0 or 1 + t_label_h = comp[..., c_img : -c_sam] + t_label_h = np.rint(t_label_h) + t_img = comp[..., 0 : c_img ] + t_sam = np.rint(comp[..., -c_sam:]) + + # intensity transform + t_img = intensity_tfx(t_img) + + if use_onehot is True: + t_label = t_label_h + else: + t_label = np.expand_dims(np.argmax(t_label_h, axis = -1), -1) + return t_img, t_label, t_sam + + return transform + + + + +def gamma_concern(img, gamma): + mean = np.mean(img) + + img = (img - mean) * gamma + img = img + mean + img = np.clip(img, 0, 1) + + return img + +def gamma_power(img, gamma, direction=0): + if direction == 1: + img = 1 - img + img = np.power(img, gamma) + + img = img / np.max(img) + if direction == 1: + img = 1 - img + + return img + +def gamma_exp(img, gamma, direction=0): + if direction == 1: + img = 1 - img + + img = np.exp(img * gamma) + img = img / np.max(img) + + if direction == 1: + img = 1 - img + return img + + +def GammaInterference(img): + # Shape Span + gamma = np.random.random() * 1.5 + 0.25 # 0.25 ~ 1.75 + # gamma = np.random.random() * 1.75 + 0.25 # 0.25 ~ 1.75 + img = gamma_concern(img, gamma) # 
concerntrate + + # Shape Tilt + choose = np.random.randint(0, 2) + direction = np.random.randint(0, 2) + + if choose == 0: + gamma = 0.2 + np.random.random() * 2.3 # 2.5 + img = gamma_power(img, gamma, direction) + else: + gamma = np.random.random() * 2.3 + 0.6 # 1.5 center + img = gamma_exp(img, gamma, direction) + + return img + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/transforms.py b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ec304cf13a66ee181491abe7d9adbd31b16e3f4b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/dataset/utils/transforms.py @@ -0,0 +1,487 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch + +from dataset.m4_utils.math import ifft2c, fft2c, complex_abs +from dataset.m4_utils.subsample import create_mask_for_mask_type, MaskFunc +import random + +from typing import Dict, Optional, Sequence, Tuple, Union +from matplotlib import pyplot as plt +import os + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. + + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data. + """ + data = data.numpy() + + return data[..., 0] + 1j * data[..., 1] + + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + +def mask_center(x, mask_from, mask_to): + mask = torch.zeros_like(x) + mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to] + + return mask + + +def center_crop(data, shape): + """ + Apply a center crop to the input real image or batch of real images. 
+ + Args: + data (torch.Tensor): The input tensor to be center cropped. It should + have at least 2 dimensions and the cropping is applied along the + last two dimensions. + shape (int, int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image. + """ + assert 0 < shape[0] <= data.shape[-2] + assert 0 < shape[1] <= data.shape[-1] + + w_from = (data.shape[-2] - shape[0]) // 2 + h_from = (data.shape[-1] - shape[1]) // 2 + w_to = w_from + shape[0] + h_to = h_from + shape[1] + + return data[..., w_from:w_to, h_from:h_to] + + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + +def center_crop_to_smallest(x, y): + """ + Apply a center crop on the larger image to the size of the smaller. + + The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at + dim=-1 and y is smaller than x at dim=-2, then the returned dimension will + be a mixture of the two. + + Args: + x (torch.Tensor): The first image. + y (torch.Tensor): The second image + + Returns: + tuple: tuple of tensors x and y, each cropped to the minimim size. + """ + smallest_width = min(x.shape[-1], y.shape[-1]) + smallest_height = min(x.shape[-2], y.shape[-2]) + x = center_crop(x, (smallest_height, smallest_width)) + y = center_crop(y, (smallest_height, smallest_width)) + + return x, y + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +class DataTransform(object): + """ + Data Transformer for training U-Net models. + """ + + def __init__(self, which_challenge): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. 
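A small sketch of the crop/normalise helpers defined above (center_crop, normalize_instance); the 320 x 320 input is an assumed size for illustration.

import torch

img = torch.randn(320, 320)
img = center_crop(img, (240, 240))                    # centred 240 x 240 crop
img, mean, std = normalize_instance(img, eps=1e-11)   # zero mean, unit std
img = img.clamp(-6, 6)                                # same clipping used by the transforms below
restored = img * std + mean                           # undo the normalisation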
This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.which_challenge = which_challenge + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. + """ + kspace = to_tensor(kspace) + + # inverse Fourier transform to get zero filled solution + image = ifft2c(kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + + # getLR + imgfft = fft2c(image) + imgfft = complex_center_crop(imgfft, (160, 160)) + LR_image = ifft2c(imgfft) + + # absolute value + LR_image = complex_abs(LR_image) + + # normalize input + LR_image, mean, std = normalize_instance(LR_image, eps=1e-11) + LR_image = LR_image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return LR_image, target, mean, std, fname, slice_num + +class DenoiseDataTransform(object): + def __init__(self, size, noise_rate): + super(DenoiseDataTransform, self).__init__() + self.size = (size, size) + self.noise_rate = noise_rate + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + max_value = attrs["max"] + + #target + target = to_tensor(target) + target = center_crop(target, self.size) + target, mean, std = normalize_instance(target, eps=1e-11) + target = target.clamp(-6, 6) + + #image + kspace = to_tensor(kspace) + complex_image = ifft2c(kspace) #complex_image + image = complex_center_crop(complex_image, self.size) + noise_image = self.rician_noise(image, max_value) + noise_image = complex_abs(noise_image) + + noise_image = normalize(noise_image, mean, std, eps=1e-11) + noise_image = noise_image.clamp(-6, 6) + + return noise_image, target, mean, std, fname, slice_num + + + def rician_noise(self, X, noise_std): + #Add rician noise with variance sampled uniformly from the range 0 and 0.1 + noise_std = random.uniform(0, noise_std*self.noise_rate) + Ir = X + noise_std * torch.randn(X.shape) + Ii = noise_std*torch.randn(X.shape) + In = torch.sqrt(Ir ** 2 + Ii ** 2) + return In + + +def apply_mask( + data: torch.Tensor, + mask_func: MaskFunc, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Optional[Sequence[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Subsample given k-space by multiplying 
with a mask. + Args: + data: The input k-space data. This should have at least 3 dimensions, + where dimensions -3 and -2 are the spatial dimensions, and the + final dimension has size 2 (for complex values). + mask_func: A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed: Seed for the random number generator. + padding: Padding value to apply for mask. + Returns: + tuple containing: + masked data: Subsampled k-space data + mask: The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + + +class ReconstructionTransform(object): + """ + Data Transformer for training U-Net models. + """ + + def __init__(self, which_challenge, mask_func=None, use_seed=True): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.mask_func = mask_func + self.which_challenge = which_challenge + self.use_seed = use_seed + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. 
+ """ + kspace = to_tensor(kspace) + + # apply mask + if self.mask_func: + seed = None if not self.use_seed else tuple(map(ord, fname)) + masked_kspace, mask = apply_mask(kspace, self.mask_func, seed) + else: + masked_kspace = kspace + + # inverse Fourier transform to get zero filled solution + image = ifft2c(masked_kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + # print('image',image.shape) + # absolute value + image = complex_abs(image) + + # apply Root-Sum-of-Squares if multicoil data + if self.which_challenge == "multicoil": + image = rss(image) + + # normalize input + image, mean, std = normalize_instance(image, eps=1e-11) + image = image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return image, target, mean, std, fname, slice_num + + +def build_transforms(MASKTYPE, CENTER_FRACTIONS, ACCELERATIONS, mode = 'train'): + + challenge = 'singlecoil' + return ReconstructionTransform(challenge) + + # if mode == 'train': + # mask = create_mask_for_mask_type( + # MASKTYPE, CENTER_FRACTIONS, ACCELERATIONS, + # ) + # return ReconstructionTransform(challenge, mask, use_seed=False) + # + # elif mode == 'val' or mode == 'test': + # mask = create_mask_for_mask_type( + # MASKTYPE, CENTER_FRACTIONS, ACCELERATIONS, + # ) + # return ReconstructionTransform(challenge, mask) + # + # else: + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..946aef9fd093eb73d770679740b159912e0dad4d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/__init__.py @@ -0,0 +1,2 @@ +from diffusion_pytorch.diffusion_gaussian import GaussianDiffusion, Trainer +# from diffusion_pytorch.new_twobranch_model import Model as TwoBranchNewModel diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/brats_mask.py b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/brats_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..2a70094da5195c6a78ed118bc5e8b352adf1b5b6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/brats_mask.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/04/08 +对BRATS 2020数据集进行Pre-processing, 得到各个模态的under-sampled input image和2d groung-truth. 
+""" +import os +import argparse +import numpy as np +import nibabel as nib +from scipy import ndimage as nd +from scipy import ndimage +from skimage import filters +from skimage import io +import torch +import torch.fft +from matplotlib import pyplot as plt + +MRIDOWN=8 +SNR = 0 + + +class MaskFunc_Cartesian: + """ + MaskFunc creates a sub-sampling mask of a given shape. + The mask selects a subset of columns from the input k-space data. If the k-space data has N + columns, the mask picks out: + a) N_low_freqs = (N * center_fraction) columns in the center corresponding to + low-frequencies + b) The other columns are selected uniformly at random with a probability equal to: + prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). + This ensures that the expected number of columns selected is equal to (N / acceleration) + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is + called. + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there + is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% + probability that 8-fold acceleration with 4% center fraction is selected. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly + each time. + accelerations (List[int]): Amount of under-sampling. This should have the same length + as center_fractions. If multiple values are provided, then one of these is chosen + uniformly each time. An acceleration of 4 retains 25% of the columns, but they may + not be spaced evenly. + """ + if len(center_fractions) != len(accelerations): + raise ValueError('Number of center fractions should match number of accelerations') + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random.RandomState() + + def __call__(self, shape, seed=None): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The shape should have + at least 3 dimensions. Samples are drawn along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting the seed + ensures the same mask is generated each time for the same shape. + Returns: + torch.Tensor: A mask of the specified shape. 
+ """ + if len(shape) < 3: + raise ValueError('Shape should have 3 or more dimensions') + + self.rng.seed(seed) + num_cols = shape[-2] + + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + # Create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs + 1e-10) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad:pad + num_low_freqs] = True + + # Reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + mask = mask.repeat(shape[0], 1, 1) + + return mask + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2)) + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + spectrum = spectrum * mask[None, :, :, None] + return spectrum + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2)) + + return image + + +def get_undersample(): + ff = MaskFunc_Cartesian([0.2], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4 + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] + + + plt.imshow(mask) + plt.show() + + +def simulate_undersample_mri(raw_mri): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + ff = MaskFunc_Cartesian([0.2], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4 + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + kspace = mri_fourier_transform_2d(mri, mask) + kspace = add_gaussian_noise(kspace) + mri_recon = mri_inver_fourier_transform_2d(kspace) + kdata = torch.sqrt(kspace.real ** 2 + kspace.imag ** 2 + 1e-10) + kdata = kdata.data.numpy()[0, :, :, 0] + + under_img = torch.sqrt(mri_recon.real ** 2 + mri_recon.imag ** 2) + under_img = under_img.data.numpy()[0, :, :, 0] + + return under_img, kspace + + +def add_gaussian_noise(img, snr=15): + ### 根据SNR确定noise的放大比例 + num_pixels = img.shape[0]*img.shape[1]*img.shape[2]*img.shape[3] + psr = torch.sum(torch.abs(img.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(img.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(img.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(img.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noise_img = img + noise + # print("original image:", img) + # print("gaussian noise:", noise) + + return noise_img + + +def complexsing_addnoise(img, snr): + ### add noise to the real part of the image. + img_numpy = img.cpu().numpy() + # print("kspace data:", img) + s_r = np.real(img_numpy) + num_pixels = s_r.shape[0]*s_r.shape[1]*s_r.shape[2]*s_r.shape[3] + psr = np.sum(np.abs(s_r)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + # print("PSR:", psr, "PNR:", pnr) + noise_r = np.random.randn(num_pixels)*np.sqrt(pnr) + + ### add noise to the iamginary part of the image. 
+ s_im = np.imag(img_numpy) + psim = np.sum(np.abs(s_im)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = np.random.randn(num_pixels)*np.sqrt(pnim) + + noise = torch.Tensor(noise_r) + 1j*torch.Tensor(noise_im) + sn = img + noise + # print("noisy data:", sn) + # sn = torch.Tensor(sn) + + return sn + + + +def _parse(rootdir): + filetree = {} + + for sample_file in os.listdir(rootdir): + sample_dir = rootdir + sample_file + subject = sample_file + + for filename in os.listdir(sample_dir): + modality = filename.split('.').pop(0).split('_')[-1] + + if subject not in filetree: + filetree[subject] = {} + filetree[subject][modality] = filename + + return filetree + + + +def clean(rootdir, savedir, source_modality, target_modality): + filetree = _parse(rootdir) + print("filetree:", filetree) + + if not os.path.exists(savedir+'/img_norm'): + os.makedirs(savedir+'/img_norm') + + for subject, modalities in filetree.items(): + print(f'{subject}:') + + if source_modality not in modalities or target_modality not in modalities: + print('-> incomplete') + continue + + source_path = os.path.join(rootdir, subject, modalities[source_modality]) + target_path = os.path.join(rootdir, subject, modalities[target_modality]) + + source_image = nib.load(source_path) + target_image = nib.load(target_path) + + source_volume = source_image.get_fdata() + target_volume = target_image.get_fdata() + source_binary_volume = np.zeros_like(source_volume) + target_binary_volume = np.zeros_like(target_volume) + + print("source volume:", source_volume.shape) + print("target volume:", target_volume.shape) + + for i in range(source_binary_volume.shape[-1]): + source_slice = source_volume[:, :, i] + target_slice = target_volume[:, :, i] + + if source_slice.min() == source_slice.max(): + print("invalide source slice") + source_binary_volume[:, :, i] = np.zeros_like(source_slice) + else: + source_binary_volume[:, :, i] = ndimage.morphology.binary_fill_holes( + source_slice > filters.threshold_li(source_slice)) + + if target_slice.min() == target_slice.max(): + print("invalide target slice") + target_binary_volume[:, :, i] = np.zeros_like(target_slice) + else: + target_binary_volume[:, :, i] = ndimage.morphology.binary_fill_holes( + target_slice > filters.threshold_li(target_slice)) + + source_volume = np.where(source_binary_volume, source_volume, np.ones_like( + source_volume) * source_volume.min()) + target_volume = np.where(target_binary_volume, target_volume, np.ones_like( + target_volume) * target_volume.min()) + ## resize + if source_image.header.get_zooms()[0] < 0.6: + scale = np.asarray([240, 240, source_volume.shape[-1]]) / np.asarray(source_volume.shape) + source_volume = nd.zoom(source_volume, zoom=scale, order=3, prefilter=False) + target_volume = nd.zoom(target_volume, zoom=scale, order=0, prefilter=False) + + # save volume into images + source_volume = (source_volume-source_volume.min())/(source_volume.max()-source_volume.min()) + target_volume = (target_volume-target_volume.min())/(target_volume.max()-target_volume.min()) + + for i in range(source_binary_volume.shape[-1]): + source_binary_slice = source_binary_volume[:, :, i] + target_binary_slice = target_binary_volume[:, :, i] + if source_binary_slice.max() > 0 and target_binary_slice.max() > 0: + dd = target_volume.shape[0] // 2 + target_slice = target_volume[dd - 120:dd + 120, dd - 120:dd + 120, i] + source_slice = source_volume[dd - 120:dd + 120, dd - 120:dd + 120, i] + print("source slice range:", source_slice.shape) + print("target slice range:", 
target_slice.max(), target_slice.min()) + # undersample MRI + source_under_img, source_kspace = simulate_undersample_mri(source_slice) + target_under_img, target_kspace = simulate_undersample_mri(target_slice) + + # # io.imsave(savedir+'/img_norm/'+subject+'_'+str(i)+'_'+source_modality+'.png', (source_slice * 255.0).astype(np.uint8)) + io.imsave(savedir + '/img_temp/' + subject + '_' + str(i) + '_' + source_modality + '_' + str(MRIDOWN) + 'X_' + str(SNR) + 'dB_undermri.png', + (source_under_img * 255.0).astype(np.uint8)) + + # io.imsave(savedir + '/img_temp/' + subject + '_' + str(i) + '_' + source_modality + '_' + str(MRIDOWN) + 'X_undermri.png', + # (source_under_img * 255.0).astype(np.uint8)) + # # io.imsave(savedir+'/img_norm/'+subject+'_'+str(i)+'_'+target_modality+'.png', (target_slice * 255.0).astype(np.uint8)) + # io.imsave(savedir + '/img_norm/' + subject + '_' + str(i) + '_' + target_modality + '_' + str(MRIDOWN) + 'X_undermri.png', + # (target_under_img * 255.0).astype(np.uint8)) + + # np.savez_compressed(rootdir + '/img_norm/' + subject + '_' + str(i) + '_' + target_modality + '_raw_'+str(MRIDOWN)+'X'+str(CTNVIEW)+'P', + # kspace=kspace, under_t1=under_img, + # t1=source_slice, ct=target_slice) + + +def main(args): + clean(args.rootdir,args.savedir, args.source, args.target) + + +if __name__ == '__main__': + get_undersample() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/example_mask/brats_4X_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/example_mask/brats_4X_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..bdf32304f95640286541ceb1068582dc69b0d60a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/example_mask/brats_4X_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76341ba680a0bc9c80389e01f8511e5bd99ab361eeb48d83516904b84cccc518 +size 460928 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/example_mask/brats_8X_mask.npy b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/example_mask/brats_8X_mask.npy new file mode 100644 index 0000000000000000000000000000000000000000..c389e708adeb3307db90ff071599256b8f59dab5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/example_mask/brats_8X_mask.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c5160add079e8f4dc2496e5ef87c110015026d9f6116329da2238a73d8bc104 +size 230528 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/extract_example_mask.py b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/extract_example_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..630c51d74f70c00fba605fd06761eb9a73c9d3e9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/extract_example_mask.py @@ -0,0 +1,80 @@ +import matplotlib.pyplot as plt +import torch +import numpy as np +from torch.fft import fft2, ifft2, fftshift, ifftshift + +# brats 4X +example = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/BraTS20_Training_036_90_t2_4X_undermri.png" +gt = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/BraTS20_Training_036_90_t2.png" +save_file = "./example_mask/brats_4X_mask.npy" + + +# example = 
"/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/BraTS20_Training_036_90_t2_8X_undermri.png" +# gt = "/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/BraTS20_Training_036_90_t2.png" +# save_file = "./example_mask/brats_8X_mask.npy" + + +example_img = plt.imread(example) # cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) +gt = plt.imread(gt) # cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) + +print("example_img shape: ", example_img.shape) +plt.imshow(example_img, cmap='gray') +plt.title("Example Frequency Image") +plt.show() + +example_img = torch.from_numpy(example_img).float() +fre = fftshift(fft2(example_img)) # ) + +amp = torch.log(torch.abs(fre)) +angle = torch.angle(fre) + +plt.imshow(amp.squeeze(0).squeeze(0).numpy()) +plt.show() + +plt.imshow(angle.squeeze(0).squeeze(0).numpy()) +plt.show() + + +gt_fre = fftshift(fft2(torch.from_numpy(gt).float())) # ) +gt_amp = torch.log(torch.abs(gt_fre)) +plt.imshow(gt_amp.squeeze(0).squeeze(0).numpy()) +plt.show() +gt_angle = torch.angle(gt_fre) +plt.imshow(gt_angle.squeeze(0).squeeze(0).numpy()) +plt.show() + + +amp_mask = gt_amp.squeeze(0).squeeze(0).numpy() - amp.squeeze(0).squeeze(0).numpy() +amp_mask = np.mean(amp_mask, axis=0, keepdims=True) + +thres = np.mean(amp_mask) * 0.73 + + +new_mask = (amp_mask < thres) * 1.0 +new_mask = np.repeat(new_mask, 240, axis=0) + +amp_mask[amp_mask < thres] = 1 +amp_mask[amp_mask >= thres] = 0 + + +#duplicate +amp_mask = np.repeat(amp_mask, 240, axis=0) + +plt.imshow(gt_amp.squeeze(0).squeeze(0).numpy() - amp.squeeze(0).squeeze(0).numpy()) +plt.show() +plt.imshow(gt_angle.squeeze(0).squeeze(0).numpy() - angle.squeeze(0).squeeze(0).numpy()) +plt.show() +plt.imshow(new_mask) +plt.show() + +np.save(save_file, new_mask) + + + +load_backmask = np.load(save_file) +plt.imshow(load_backmask) +plt.show() + +size = load_backmask.shape[0] * load_backmask.shape[1] +print("shape=", size, load_backmask.sum()/size) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/k_degradation.py b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/k_degradation.py new file mode 100644 index 0000000000000000000000000000000000000000..6c74f69a9437565dede0875af404e234819445af --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/k_degradation.py @@ -0,0 +1,512 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift +try: + from diffusion_pytorch.degradation.mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquiSpacedMaskFunc, RandomPatchFunc +except: + from mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquiSpacedMaskFunc, RandomPatchFunc + +from torch import nn +import matplotlib.pyplot as plt + +def get_fade_kernel(dims, std): + fade_kernel = tgm.image.get_gaussian_kernel2d(dims, std) + fade_kernel = fade_kernel / torch.max(fade_kernel) + fade_kernel = torch.ones_like(fade_kernel) - fade_kernel + # if device_of_kernel == 'cuda': + # fade_kernel = fade_kernel.cuda() + fade_kernel = fade_kernel[1:, 1:] + return fade_kernel + + + +def get_fade_kernels(fade_routine, num_timesteps, image_size, kernel_std,initial_mask): + kernels = [] + for i in range(num_timesteps): + if fade_routine == 'Incremental': + kernels.append(get_fade_kernel((image_size + 1, image_size + 1), + (kernel_std * (i + initial_mask), 
+ kernel_std * (i + initial_mask)))) + elif fade_routine == 'Constant': + kernels.append(get_fade_kernel( + (image_size + 1, image_size + 1), + (kernel_std, kernel_std))) + + elif fade_routine == 'Random_Incremental': + kernels.append(get_fade_kernel((2 * image_size + 1, 2 * image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + return torch.stack(kernels) + + +# --------------------------- +# Kspace kernels +# --------------------------- +# cartesian_regular +def get_mask_func(mask_method, af, cf): + if mask_method == 'cartesian_regular': + return EquispacedMaskFractionFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == 'cartesian_random': + return RandomMaskFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == "random": + return RandomMaskFunc([cf], [af]) + + elif mask_method == "randompatch": + return RandomPatchFunc([cf], [af]) + + elif mask_method == "equispaced": + return EquiSpacedMaskFunc([cf], [af]) + + else: + raise NotImplementedError + + +use_fix_center_ratio = False + +class Noisy_Patch(nn.Module): + def __init__(self): + super(Noisy_Patch, self).__init__() + self.af_list = [] + self.cf_list = [] + self.fe_list = [] + self.pe_list = [] + self.seed = 0 + + def append_list(self, at, cf, fe, pe): + self.af_list.append(at) + self.cf_list.append(cf) + self.fe_list.append(fe) + self.pe_list.append(pe) + + def get_noisy_patches(self, t): + af = self.af_list[t] + cf = self.cf_list[t] + fe = self.fe_list[t] + pe = self.pe_list[t] + + patch_mask = get_mask_func("randompatch", af, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=self.seed) # mask (numpy): (fe, pe) + return mask_ + + def forward(self, mask, ts): + # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + # print("use_patch_kernel forward:", t) + # print("mask = ", mask.shape) + # masks_ = [] + for id, t in enumerate(ts): + mask_ = self.get_noisy_patches(t)[0] + # print("mask_ = ", mask_.shape) + # print("mask[id, t] =", mask[t].shape) + + mask[t] = mask_.to(mask[t].device) * mask[t] + self.seed += ts[0].item() + + # masks_ = torch.stack(masks_).cuda() + # print("masks_ = ", masks_.shape) + # print("mask = ", mask.shape) # B, T, H, W + + return mask + +get_noisy_patches = Noisy_Patch() + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False): + mask_func = get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random', 'equispaced']: + + mask, num_low_freq = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'equispaced': + mask, num_low_freq = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = 
mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (pe, pe) --> (1, pe, pe) + + else: + raise NotImplementedError + + # print("return mask = ", mask.shape) + return mask + + + +def get_ksu_kernel(timesteps, image_size, + ksu_routine="LogSamplingRate", + mask_method="cartesian_random", + accelerated_factor=4, is_training = False, example_frequency_img=None): + + + if accelerated_factor == 4: + if is_training: + mask_method, center_fraction = "cartesian_random", 0.08 # 0.15 + else: + mask_method, center_fraction = "cartesian_random", 0.08 # 0.08 + + elif accelerated_factor == 8: + if is_training: + mask_method, center_fraction = "equispaced", 0.04 # 0.04 + else: + mask_method, center_fraction = "equispaced", 0.04 + + + center_ratio_factor = center_fraction * accelerated_factor + + masks = [] + noisy_masks = [] + ksu_mask_pe = ksu_mask_fe = image_size # , ksu_mask_pe=320, ksu_mask_fe=320 + # ksu_mask_fe + if ksu_routine == 'LinearSamplingRate': + # Generate the sampling rate list with torch.linspace, reversed, and skip the first element + sr_list = torch.linspace(start=1/accelerated_factor, end=1, steps=timesteps + 1).flip(0) + # Start from 0.01 + for sr in sr_list: + af = 1 / sr # * accelerated_factor # acceleration factor + cf = center_fraction if use_fix_center_ratio else sr_list[0] * center_ratio_factor + + masks.append(get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe, is_training=is_training)) + + elif ksu_routine == 'LogSamplingRate': + # MRI-Specific Masking: + # Design the frequency masking schedule to prioritize central k-space frequencies early in the process. + # Central k-space contains low-frequency information critical for image contrast. + + + # Generate the sampling rate list with torch.logspace, reversed, and skip the first element + sr_list = torch.logspace(start=-torch.log10(torch.tensor(accelerated_factor)), + end=0, steps=timesteps + 1).flip(0) + + af = 1 / sr_list[-1] + cf = center_fraction if use_fix_center_ratio else sr_list[-1] * center_ratio_factor + # print("af = ", af, cf) + + # Full + if isinstance(example_frequency_img, str): + # read in image and get frequency space: + example_img = plt.imread(example_frequency_img) #cv2.imread(example_frequency_img, cv2.IMREAD_GRAYSCALE) + print("example_img shape: ", example_img.shape) + plt.imshow(example_img, cmap='gray') + plt.title("Example Frequency Image") + plt.show() + + example_img = torch.from_numpy(example_img).float() + fre = fftshift(fft2(example_img) ) # ) + amp = torch.log(torch.abs(fre)) + plt.imshow(amp.squeeze(0).squeeze(0).numpy()) + plt.show() + angle= torch.angle(fre) + plt.imshow(angle.squeeze(0).squeeze(0).numpy()) + plt.show() + + + cache_mask = get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe) + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + masks.append(cache_mask) + + sr_list = sr_list[:-1].flip(0) # Flip? 
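+        # What the loop below builds (illustrative walk-through, assuming
+        # accelerated_factor=4): the remaining sampling rates climb log-spaced
+        # from just above 1/accelerated_factor up to 1. Starting from the most
+        # undersampled cache_mask, each step computes af = 1/sr, keeps every
+        # previously sampled column, and turns on the
+        # int(total_lines / af) - len(existing_lines) still-missing columns
+        # closest to the k-space centre, so low frequencies are filled in first.
+        # The list therefore grows from heavily undersampled to fully sampled
+        # and is reversed afterwards (masks = masks[::-1]) so that index 0
+        # holds the densest mask.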
+ + for sr in sr_list: + af = 1 / sr + cf = center_fraction if use_fix_center_ratio else sr * center_ratio_factor + # print("af = ", af, cf) + + H, W = cache_mask.shape[1], cache_mask.shape[2] + new_mask = cache_mask.clone() + + # Add additional lines to the mask based on new acceleration factor + total_lines = H + sampled_lines = int(total_lines / af) + existing_lines = new_mask.squeeze(0).sum(dim=0).nonzero(as_tuple=True)[0].tolist() + + remaining_lines = [i for i in range(total_lines) if i not in existing_lines] + + if sampled_lines > len(existing_lines): + center = W // 2 + additional_lines = sampled_lines - len(existing_lines) # sample number + + sorted_indices = sorted(remaining_lines, key=lambda x: abs(x - center)) + + # Take the closest `additional_lines` indices + sampled_indices = sorted_indices[:additional_lines] + + # Remove sampled indices from remaining_lines + for idx in sampled_indices: + remaining_lines.remove(idx) + + # Update new_mask for each sampled index + for idx in sampled_indices: + new_mask[:, :, idx] = 1.0 + + # if sampled_lines > len(existing_lines): + # additional_lines = sampled_lines - len(existing_lines) # sample number + # + # # Random line + # # sampled_indices = np.random.choice(remaining_lines, additional_lines, replace=False) + # + # # Close to the center + # center = W // 2 # Calculate the center index + # # Find the indices of zeros closest to the center + # sampled_indices = sorted(remaining_lines, key=lambda x: abs(x - center))[0] + # remaining_lines.remove(sampled_indices) + # + # # sampled_indices = + # new_mask[:, :, sampled_indices] = 1.0 + + + + cache_mask = new_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + + masks.append(cache_mask) + + # reverse + masks = masks[::-1] + noisy_masks = masks # noisy_masks[::-1] + + + elif mask_method == 'gaussian_2d': + raise NotImplementedError("Gaussian 2D mask type is not implemented.") + + else: + raise NotImplementedError(f'Unknown k-space undersampling routine {ksu_routine}') + + # Return masks, excluding the first one + return masks#, noisy_masks[1:] + + + +class high_fre_mask: + def __init__(self): + self.mask_cache = {} + + def __call__(self, H, W): + if (H, W) in self.mask_cache: + return self.mask_cache[(H, W)] + center_x, center_y = H // 2, W // 2 + radius = H//8 # 影响的频率范围半径 + + high_freq_mask = torch.ones(H, W) + for i in range(H): + for j in range(W): + if (i - center_x) ** 2 + (j - center_y) ** 2 <= radius ** 2: + high_freq_mask[i, j] = 0.0 + self.mask_cache[(H, W)] = high_freq_mask + return high_freq_mask + + +high_fre_mask_cls = high_fre_mask() + + + +def apply_ksu_kernel(x_start, mask, params_dict=None, pixel_range='mean_std', + use_fre_noise=False): + fft, mask = apply_tofre(x_start, mask) + + + # Use the high frequency mask to add noise + if use_fre_noise: + fft = fft * mask + fft_magnitude = torch.abs(fft) # 幅度 + fft_phase = torch.angle(fft) # 相位 + + _, _, H, W = fft.shape + + high_freq_mask = high_fre_mask_cls(H, W).to(fft.device) + high_freq_mask = high_freq_mask.unsqueeze(0).unsqueeze(0).repeat(fft.shape[0], 1, 1, 1) + + # Background Noise + sigma = 0.1 * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + # noise_magnitude = sigma * fft_magnitude # fft_magnitude.mean() + mean_mag = fft_magnitude.sum() / (mask.sum() + 1) + # print("mean_mag = ", mean_mag) + + noise_magnitude_high = noise * (mean_mag) * (1 - mask) # high_freq_mask + + # Add noise to unmasked frequencies to maintain the 
stochasticity that diffusion models typically rely on. + # This prevents overfitting to specific frequency ranges. + sigma = 5/255 * torch.abs(torch.randn(1)).item() + noise = torch.randn_like(fft_magnitude) * sigma + noise_magnitude_low = noise * fft_magnitude * mask # (1 - high_freq_mask) + + # fft_noisy_magnitude = fft_magnitude * mask + noise_magnitude * high_freq_mask * (1 - mask) + fft_noisy_magnitude = fft_magnitude * mask + fft_noisy_magnitude += noise_magnitude_low # + noise_magnitude_high + fft_noisy_magnitude = torch.clamp(fft_noisy_magnitude, min=0.0) + + fft = fft_noisy_magnitude * torch.exp(1j * fft_phase) + + else: + fft = fft * mask + + + x_ksu = apply_to_spatial(fft) + + return x_ksu + + + +def apply_tofre(x_start, mask): + kspace = fftshift(fft2(x_start, norm=None, dim=(-2, -1)), dim=(-2, -1)) # Default: all dimensions + mask = mask.to(kspace.device) + return kspace, mask + +def apply_to_spatial(fft): + + x_ksu = ifft2(ifftshift(fft, dim=(-2, -1)), norm=None, dim=(-2, -1)) # ortho + x_ksu = x_ksu.real #torch.abs(x_ksu) # + + return x_ksu + + +if __name__ == "__main__": + # First STEP + import matplotlib.pyplot as plt + import numpy as np + + image_size = 64 + + masks = get_ksu_kernel(25, image_size, + "LinearSamplingRate", is_training=True) # LogSamplingRate + + + batch_size = 1 + + img = plt.imread("/Users/haochen/Documents/GitHub/Frequency-Diffusion/draw/assets/BraTS20_Training_001_86_t1.png") + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + # to gray scale + # img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + # img = np.transpose(img, (2, 0, 1)) + # img = img[0] + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + print(" img shape: ", img.shape, img.max(), img.min()) + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + print("rand_x shape:", rand_x.shape, rand_x) + + img = img * 2 - 1 # + + masked_img = [] + + # masks = np.asarray(masks) + for m in masks: + print("m shape: ", m.shape) + m = m.unsqueeze(0) + img = apply_ksu_kernel(img, m, pixel_range='-1_1', ) + masked_img.append(img) + + masks = np.concatenate(masks, axis=-1)[0] + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + print(" masked_img shape: ", masked_img.shape) + print(" mask shape: ", masks.shape) + + img = np.concatenate([masks, masked_img], axis=0) + + plt.figure(figsize=(100, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + + print("\n\nSecond stage...") + + + # Second STEP + import matplotlib.pyplot as plt + import numpy as np + + image_size = 64 + batch_size = 1 + t = 25 + kspace_kernels = get_ksu_kernel(t, image_size, ksu_routine="LogSamplingRate", is_training=True) # 2 * + kspace_kernels = torch.stack(kspace_kernels).squeeze(1) + + img = plt.imread( + "/Users/haochen/Documents/GitHub/Frequency-Diffusion/draw/assets/BraTS20_Training_001_86_t1.png") + img = cv2.resize(img, (image_size, image_size)) + + # img = np.transpose(img, (2, 0, 1)) + # img = img[0] + + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + print(" img shape: ", img.shape, img.max(), img.min()) + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + print("rand_x shape:", rand_x.shape, rand_x) + + for i in range(batch_size): + print("kspace_kernels[j] shape = ", kspace_kernels[i].shape, 
rand_x[i]) + # k = kspace_kernels[j].clone() + rand_kernels.append(torch.stack( + [kspace_kernels[j][: image_size, # rand_x[i]:rand_x[i] + + : image_size] for j in + range(len(kspace_kernels))])) + + rand_kernels = torch.stack(rand_kernels) + + # rand_kernels shape: torch.Size([24, 5, 128, 128]) + print("=== rand_kernels: ", rand_kernels.shape, kspace_kernels[0].shape) + + masked_img = [] + masks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + masks.append(k) + # print("-- k shape: ", k.shape) + # print("-- img shape: ", img.shape) + + img = apply_ksu_kernel(img, k, pixel_range='0_1') + masked_img.append(img) + + masks = np.concatenate(masks, axis=-1)[0] + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + # print(" masked_img shape: ", masked_img.shape) + # print(" mask shape: ", masks.shape) + + img = np.concatenate([masks, masked_img], axis=0) + + plt.figure(figsize=(100, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/kspace_test.py b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/kspace_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1b00fc4c3af61773497301d2bc5344642c1b4a9a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/kspace_test.py @@ -0,0 +1,272 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift +import matplotlib.pyplot as plt + +try: + from diffusion_pytorch.degradation.mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquiSpacedMaskFunc, \ + RandomPatchFunc +except: + from mask_utils import RandomMaskFunc, EquispacedMaskFractionFunc, EquiSpacedMaskFunc, RandomPatchFunc + +try: + from .k_degradation import get_fade_kernel, get_fade_kernels, get_mask_func, high_fre_mask, get_ksu_kernel +except: + from k_degradation import get_fade_kernel, get_fade_kernels, get_mask_func, high_fre_mask, get_ksu_kernel + + +use_fix_center_ratio = False + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False): + mask_func = get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random']: + + mask, num_low_freq = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + if is_training: # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + af_new = 1.0 + (af - 1.0) / 2 + # af_new = max(af_new, 1.0) + + patch_mask = get_mask_func("randompatch", af_new, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=seed) # mask (numpy): (fe, pe) + + mask = mask_ * mask + + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # 
mask (torch): (pe, pe) --> (1, pe, pe)
+
+    else:
+        raise NotImplementedError
+
+    # print("return mask = ", mask.shape)
+    return mask
+
+
+# ksu_masks = get_ksu_kernels()
+# (C, H, W) --> (B, C, H, W)
+
+
+high_fre_mask_cls = high_fre_mask()
+
+
+def apply_ksu_kernel(x_start, mask, params_dict=None, pixel_range='mean_std',
+                     use_fre_noise=False, return_mask=False):
+    fft, mask = apply_tofre(x_start, mask, params_dict, pixel_range)
+
+    # Use the high frequency mask to add noise
+    if use_fre_noise:
+        fft = fft * mask
+        fft_magnitude = torch.abs(fft)  # magnitude
+        fft_phase = torch.angle(fft)  # phase
+
+        _, _, H, W = fft.shape
+
+        high_freq_mask = high_fre_mask_cls(H, W).to(fft.device)
+        high_freq_mask = high_freq_mask.unsqueeze(0).unsqueeze(0).repeat(fft.shape[0], 1, 1, 1)
+
+        # Background Noise
+        sigma = 0.2
+        noise = torch.randn_like(fft_magnitude) * sigma
+        mean_mag = fft_magnitude.sum() / (mask.sum() + 1)
+
+        noise_magnitude_high = noise * (mean_mag) * (1 - mask)  # high_freq_mask
+
+        sigma = 0.1
+        noise = torch.randn_like(fft_magnitude) * sigma
+        noise_magnitude_low = noise * fft_magnitude * mask  # (1 - high_freq_mask)
+
+        # fft_noisy_magnitude = fft_magnitude * mask + noise_magnitude * high_freq_mask * (1 - mask)
+        fft_noisy_magnitude = fft_magnitude * mask
+        fft_noisy_magnitude += noise_magnitude_high + noise_magnitude_low
+        fft_noisy_magnitude = torch.clamp(fft_noisy_magnitude, min=0.0)
+
+        fft = fft_noisy_magnitude * torch.exp(1j * fft_phase)
+
+    else:
+        fft = fft * mask
+
+    x_ksu = apply_to_spatial(fft, params_dict, pixel_range)
+    if return_mask:
+        return x_ksu, fft, fft_magnitude
+
+    return x_ksu
+
+
+def apply_tofre(x_start, mask, params_dict=None, pixel_range='mean_std'):
+    fft = fftshift(fft2(x_start))
+    mask = mask.to(fft.device)
+    return fft, mask  # , _min, _max
+
+
+def apply_to_spatial(fft, params_dict=None, pixel_range='mean_std'):
+    x_ksu = ifft2(ifftshift(fft))
+    x_ksu = torch.abs(x_ksu)
+
+    return x_ksu
+
+
+if __name__ == "__main__":
+    # First STEP
+    import SimpleITK as sitk
+
+    import numpy as np
+    import os
+
+    image_size = 240
+    batch_size = 1
+    t = 5
+
+    use_linux = True
+
+    # Load MRI back here
+    if use_linux:
+        root = "/gamedrive/Datasets/medical/Brain/brats/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData"
+        p_id = 639
+        modality = "T1C"
+        filename = f"{root}/BraTS-GLI-{p_id:05d}-000/BraTS-GLI-{p_id:05d}-000-{modality.lower()}.nii.gz"
+        img_obj = sitk.ReadImage(filename)
+        img_array = sitk.GetArrayFromImage(img_obj)
+
+        slice = img_array.shape[0] // 2
+        img = img_array[slice, ...]
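+        # Shape walk-through for the test below (assuming a 240x240 slice): the
+        # slice is expanded to a (1, 1, 240, 240) float tensor before
+        # apply_ksu_kernel(); apply_tofre() turns it into a complex k-space
+        # tensor of the same shape via fftshift(fft2(x)), the (1, fe, pe) line
+        # mask broadcasts over the batch and channel dims, and apply_to_spatial()
+        # maps the masked spectrum back to a real (1, 1, 240, 240) image with
+        # abs(ifft2(ifftshift(fft))).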
+ plt.imshow(img, cmap="gray") + plt.show() + img = (img - img.min()) / (img.max() - img.min()) + + plt.imsave("visualization/original.png", img, cmap="gray") + + else: + # Or use PNG + img = plt.imread( + "/Users/haochen/Documents/GitHub/DiffDecomp/Cold-Diffusion/generation-diffusion-pytorch/diffusion_pytorch/assets/img.png") + img = np.transpose(img, (2, 0, 1))[0] + + + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + print("img shape=", img.shape) + + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + + ksu_routine = "LogSamplingRate" # "LinearSamplingRate" # + kspace_kernels, patch_drop_masks = get_ksu_kernel(t, image_size, + ksu_routine=ksu_routine, is_training=True, + example_frequency_img=example) + kspace_kernels = torch.stack(kspace_kernels).squeeze(1) + + # all k_space + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + for i in range(batch_size): + # k = kspace_kernels[j].clone() + rand_kernels.append(torch.stack( + [kspace_kernels[j][: image_size, # rand_x[i]:rand_x[i] + + : image_size] for j in + range(len(kspace_kernels))])) + + rand_kernels = torch.stack(rand_kernels) + + # rand_kernels shape: torch.Size([24, 5, 128, 128]) + ori_img = img + + masked_img = [] + masks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + masks.append(k) + + img = apply_ksu_kernel(img, k, pixel_range='0_1') + + masked_img.append(img) + + # Visualize the masks + masks = np.concatenate(masks, axis=-1)[0] + + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # Save individually + + print("masks / masked_img=", masks.max(), masked_img.max()) + # img = np.concatenate([masks, masked_img], axis=0) + + plt.imsave("visualization/sample_masks.png", masks, cmap='gray') + + # masked_img = (masked_img - masked_img.min())/(masked_img) + # masked_img = np.concatenate([masked_img, 1-masked_img], axis=0) + plt.imsave("visualization/sample_images.png", masked_img, cmap='gray') + + w = masked_img.shape[0] + pr_folder = "visualization/progressive" + os.makedirs(pr_folder, exist_ok=True) + + # Progressive + print() + for i in range(t): + plt.imsave(f"{pr_folder}/{i}_masked_img.png", masked_img[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_masks.png", masks[:, i * w: (i + 1) * w], cmap='gray') + + img = ori_img + # use_fre_noise=False, return_mask=False + masked_img = [] + masks = [] + fft = [] + ks = [] + for i in range(t): + k = torch.stack([rand_kernels[:, i]], 1)[0] + ks.append(k) + + img, k, fft_original = apply_ksu_kernel(img, k, pixel_range='0_1', use_fre_noise=True, return_mask=True) + + # k -> fft + fft_magnitude = np.abs(k) # 幅度 + # fft_phase = torch.angle(k) # 相位 + + mag = np.log(fft_magnitude[0]) + masks.append(mag) + fft.append(np.log(fft_original[0])) + + masked_img.append(img) + + # Visualize the masks + masks = np.concatenate(masks, axis=-1)[0] + ks = np.concatenate(ks, axis=-1)[0] + + masked_img = (torch.concat(masked_img, dim=-1).numpy() + 1) * 0.5 + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + + fft = np.concatenate(fft, axis=-1)[0] + + plt.imsave("visualization/sample_noisy_mask.png", masks, cmap='gray') + + # masked_img = np.concatenate([masked_img, 1 - masked_img], axis=0) + plt.imsave("visualization/sample_noisy_image.png", masked_img, cmap='gray') + # print("masked_img shape=", masked_img.shape, w) + + # Progressive + for i in range(t): + # 
print("masked_img[:, t*w: (t+1)*w] = ", masked_img[:, t*w: (t+1)*w].shape, t*w) + + plt.imsave(f"{pr_folder}/{i}_n_masked_img.png", masked_img[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_masks.png", masks[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_fft.png", fft[:, i * w: (i + 1) * w], cmap='gray') + plt.imsave(f"{pr_folder}/{i}_n_ks.png", ks[:, i * w: (i + 1) * w], cmap='gray') + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/mask_utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2fea4433a26d0e67e3f81119a67d43a6e46598 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/mask_utils.py @@ -0,0 +1,680 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib +from typing import Optional, Sequence, Tuple, Union + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng: np.random.RandomState, seed: Optional[Union[int, Tuple[int, ...]]]): + """A context manager for temporarily adjusting the random seed.""" + if seed is None: + try: + yield + finally: + pass + else: + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +class MaskFunc: + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + + When called, ``MaskFunc`` uses internal functions create mask by 1) + creating a mask for the k-space center, 2) create a mask outside of the + k-space center, and 3) combining them into a total mask. The internals are + handled by ``sample_mask``, which calls ``calculate_center_mask`` for (1) + and ``calculate_acceleration_mask`` for (2). The combination is executed + in the ``MaskFunc`` ``__call__`` function. + + If you would like to implement a new mask, simply subclass ``MaskFunc`` + and overwrite the ``sample_mask`` logic. See examples in ``RandomMaskFunc`` + and ``EquispacedMaskFunc``. + """ + + def __init__( + self, + center_fractions: Sequence[float], + accelerations: Sequence[int], + allow_any_combination: bool = False, + seed: Optional[int] = None, + ): + """ + Args: + center_fractions: Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is + chosen uniformly each time. + accelerations: Amount of under-sampling. This should have the same + length as center_fractions. If multiple values are provided, + then one of these is chosen uniformly each time. + allow_any_combination: Whether to allow cross combinations of + elements from ``center_fractions`` and ``accelerations``. + seed: Seed for starting the internal random number generator of the + ``MaskFunc``. + """ + if len(center_fractions) != len(accelerations) and not allow_any_combination: + raise ValueError( + "Number of center fractions should match number of accelerations " + "if allow_any_combination is False." 
+ ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.allow_any_combination = allow_any_combination + self.rng = np.random.RandomState(seed) + + def __call__( + self, + shape: Sequence[int], + offset: Optional[int] = None, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + ) -> Tuple[torch.Tensor, int]: + """ + Sample and return a k-space mask. + + Args: + shape: Shape of k-space. + offset: Offset from 0 to begin mask (for equispaced masks). If no + offset is given, then one is selected randomly. + seed: Seed for random number generator for reproducibility. + + Returns: + A 2-tuple containing 1) the k-space mask and 2) the number of + center frequency lines. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_mask, accel_mask, num_low_frequencies = self.sample_mask( + shape, offset + ) + + # combine masks together + return torch.max(center_mask, accel_mask), num_low_frequencies + + def sample_mask( + self, + shape: Sequence[int], + offset: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Sample a new k-space mask. + + This function samples and returns two components of a k-space mask: 1) + the center mask (e.g., for sensitivity map calculation) and 2) the + acceleration mask (for the edge of k-space). Both of these masks, as + well as the integer of low frequency samples, are returned. + + Args: + shape: Shape of the k-space to subsample. + offset: Offset from 0 to begin mask (for equispaced masks). + + Returns: + A 3-tuple contaiing 1) the mask for the center of k-space, 2) the + mask for the high frequencies of k-space, and 3) the integer count + of low frequency samples. + """ + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + num_low_frequencies = round(float(num_cols * center_fraction)) + center_mask = self.reshape_mask( + self.calculate_center_mask(shape, num_low_frequencies), shape + ) + acceleration_mask = self.reshape_mask( + self.calculate_acceleration_mask( + num_cols, acceleration, offset, num_low_frequencies + ), + shape, + ) + + return center_mask, acceleration_mask, num_low_frequencies + + def reshape_mask(self, mask: np.ndarray, shape: Sequence[int]) -> torch.Tensor: + """Reshape mask to desired output shape.""" + num_cols = shape[-2] + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + if isinstance(mask, torch.Tensor): + return mask.view(*mask_shape).to(torch.float32) + return torch.from_numpy(mask.reshape(*mask_shape)).to(torch.float32) # torch.from_numpy( + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + """ + Produce mask for non-central acceleration lines. + + Args: + num_cols: Number of columns of k-space (2D subsampling). + acceleration: Desired acceleration rate. + offset: Offset from 0 to begin masking (for equispaced masks). + num_low_frequencies: Integer count of low-frequency lines sampled. + + Returns: + A mask for the high spatial frequencies of k-space. + """ + raise NotImplementedError + + def calculate_center_mask( + self, shape: Sequence[int], num_low_freqs: int + ) -> np.ndarray: + """ + Build center mask based on number of low frequencies. + + Args: + shape: Shape of k-space to mask. + num_low_freqs: Number of low-frequency lines to sample. + + Returns: + A mask for hte low spatial frequencies of k-space. 
+ """ + num_cols = shape[-2] + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = 1 + assert mask.sum() == num_low_freqs + + return mask + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + if self.allow_any_combination: + return self.rng.choice(self.center_fractions), self.rng.choice( + self.accelerations + ) + else: + choice = self.rng.randint(len(self.center_fractions)) + return self.center_fractions[choice], self.accelerations[choice] + + + + +class RandomMaskFunc(MaskFunc): + """ + Creates a random sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the ``RandomMaskFunc`` object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + + prob = (num_cols / acceleration - num_low_frequencies) / ( + num_cols - num_low_frequencies + ) + + # mask = self.rng.uniform(size=num_cols) < prob + # return torch.from_numpy(mask.astype(np.float32)) + + # return self.rng.uniform(size=num_cols) < prob + return torch.rand(num_cols) < prob + + + # mask = self.rng.uniform(size=num_cols) < prob + # pad = (num_cols - num_low_freqs + 1) // 2 + # mask[pad: pad + num_low_freqs] = True + # + # # reshape the mask + # mask_shape = [1 for _ in shape] + # mask_shape[-2] = num_cols + # mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + + + +class RandomPatchFunc(MaskFunc): + """ + Creates a random sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the ``RandomMaskFunc`` object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. 
+ """ + + def reshape_mask(self, mask: np.ndarray, shape: Sequence[int]) -> torch.Tensor: + """Reshape mask to desired output shape.""" + # num_cols = shape[0] * shape[1] + mask_shape = [1 for _ in shape] + mask_shape[-2] = shape[0] #num_cols + mask_shape[-1] = shape[1] + + if isinstance(mask, torch.Tensor): + return mask.view(*mask_shape).to(torch.float32) + return torch.from_numpy(mask.reshape(*mask_shape)).to(torch.float32) # torch.from_numpy( + + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + + prob = (num_cols / acceleration - num_low_frequencies) / ( + num_cols - num_low_frequencies + ) + + # mask = self.rng.uniform(size=num_cols) < prob + # return torch.from_numpy(mask.astype(np.float32)) + + # return self.rng.uniform(size=num_cols) < prob + return torch.rand(num_cols) < prob + + + def calculate_center_mask( + self, shape: Sequence[int], num_low_freqs: int + ) -> np.ndarray: + """ + Build center mask based on number of low frequencies. + + Args: + shape: Shape of k-space to mask. + num_low_freqs: Number of low-frequency lines to sample. + + Returns: + A mask for hte low spatial frequencies of k-space. + """ + # print("shape = ", shape) + num_cols = shape[0] * shape[1] + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad: pad + num_low_freqs] = 1 + assert mask.sum() == num_low_freqs + + return mask + + + def sample_mask( + self, + shape: Sequence[int], + offset: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Sample a new k-space mask. + + This function samples and returns two components of a k-space mask: 1) + the center mask (e.g., for sensitivity map calculation) and 2) the + acceleration mask (for the edge of k-space). Both of these masks, as + well as the integer of low frequency samples, are returned. + + Args: + shape: Shape of the k-space to subsample. + offset: Offset from 0 to begin mask (for equispaced masks). + + Returns: + A 3-tuple contaiing 1) the mask for the center of k-space, 2) the + mask for the high frequencies of k-space, and 3) the integer count + of low frequency samples. + """ + # print("sample mask shape= ", shape) + + + num_cols = shape[1] * shape[0] + center_fraction, acceleration = self.choose_acceleration() + num_low_frequencies = round(float(num_cols * center_fraction)) + center_mask = self.reshape_mask( + self.calculate_center_mask(shape, num_low_frequencies), shape + ) + acceleration_mask = self.reshape_mask( + self.calculate_acceleration_mask( + num_cols, acceleration, offset, num_low_frequencies + ), + shape, + ) + + return center_mask, acceleration_mask, num_low_frequencies + + + def __call__( + self, + shape: Sequence[int], + offset: Optional[int] = None, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + ) -> Tuple[torch.Tensor, int]: + """ + Sample and return a k-space mask. + + Args: + shape: Shape of k-space. + offset: Offset from 0 to begin mask (for equispaced masks). If no + offset is given, then one is selected randomly. + seed: Seed for random number generator for reproducibility. + + Returns: + A 2-tuple containing 1) the k-space mask and 2) the number of + center frequency lines. 
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_mask, accel_mask, num_low_frequencies = self.sample_mask( + shape, offset + ) + + # combine masks together + return torch.max(center_mask, accel_mask), num_low_frequencies + + + + + + +class EquiSpacedMaskFunc(MaskFunc): + """ + Sample data with equally-spaced k-space lines. + + The lines are spaced exactly evenly, as is done in standard GRAPPA-style + acquisitions. This means that with a densely-sampled center, + ``acceleration`` will be greater than the true acceleration rate. + """ + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + """ + Produce mask for non-central acceleration lines. + + Args: + num_cols: Number of columns of k-space (2D subsampling). + acceleration: Desired acceleration rate. + offset: Offset from 0 to begin masking. If no offset is specified, + then one is selected randomly. + num_low_frequencies: Not used. + + Returns: + A mask for the high spatial frequencies of k-space. + """ + if not isinstance(acceleration, int): + acceleration = int(acceleration.item()) + + if offset is None: + offset = self.rng.randint(0, high=round(acceleration)) + + mask = np.zeros(num_cols, dtype=np.float32) + mask[offset::acceleration] = 1 + + return mask + + +class EquispacedMaskFractionFunc(MaskFunc): + """ + Equispaced mask with approximate acceleration matching. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + """ + Produce mask for non-central acceleration lines. + + Args: + num_cols: Number of columns of k-space (2D subsampling). + acceleration: Desired acceleration rate. + offset: Offset from 0 to begin masking. If no offset is specified, + then one is selected randomly. + num_low_frequencies: Number of low frequencies. Used to adjust mask + to exactly match the target acceleration. + + Returns: + A mask for the high spatial frequencies of k-space. 
+ """ + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_frequencies - num_cols)) / ( + num_low_frequencies * acceleration - num_cols + ) + if offset is None: + offset = self.rng.randint(0, high=round(adjusted_accel)) + + mask = np.zeros(num_cols) + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = 1.0 + + return mask + + +class MagicMaskFunc(MaskFunc): + """ + Masking function for exploiting conjugate symmetry via offset-sampling. + + This function applies the mask described in the following paper: + + Defazio, A. (2019). Offset Sampling Improves Deep Learning based + Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint, + arXiv:1912.01101. + + It is essentially an equispaced mask with an offset for the opposite site + of k-space. Since MRI images often exhibit approximate conjugate k-space + symmetry, this mask is generally more efficient than a standard equispaced + mask. + + Similarly to ``EquispacedMaskFunc``, this mask will usually undereshoot the + target acceleration rate. + """ + + def calculate_acceleration_mask( + self, + num_cols: int, + acceleration: int, + offset: Optional[int], + num_low_frequencies: int, + ) -> np.ndarray: + """ + Produce mask for non-central acceleration lines. + + Args: + num_cols: Number of columns of k-space (2D subsampling). + acceleration: Desired acceleration rate. + offset: Offset from 0 to begin masking. If no offset is specified, + then one is selected randomly. + num_low_frequencies: Not used. + + Returns: + A mask for the high spatial frequencies of k-space. + """ + if offset is None: + offset = self.rng.randint(0, high=acceleration) + + if offset % 2 == 0: + offset_pos = offset + 1 + offset_neg = offset + 2 + else: + offset_pos = offset - 1 + 3 + offset_neg = offset - 1 + 0 + + poslen = (num_cols + 1) // 2 + neglen = num_cols - (num_cols + 1) // 2 + mask_positive = np.zeros(poslen, dtype=np.float32) + mask_negative = np.zeros(neglen, dtype=np.float32) + + mask_positive[offset_pos::acceleration] = 1 + mask_negative[offset_neg::acceleration] = 1 + mask_negative = np.flip(mask_negative) + + mask = np.concatenate((mask_positive, mask_negative)) + + return np.fft.fftshift(mask) # shift mask and return + + +class MagicMaskFractionFunc(MagicMaskFunc): + """ + Masking function for exploiting conjugate symmetry via offset-sampling. + + This function applies the mask described in the following paper: + + Defazio, A. (2019). Offset Sampling Improves Deep Learning based + Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint, + arXiv:1912.01101. + + It is essentially an equispaced mask with an offset for the opposite site + of k-space. Since MRI images often exhibit approximate conjugate k-space + symmetry, this mask is generally more efficient than a standard equispaced + mask. + + Similarly to ``EquispacedMaskFractionFunc``, this method exactly matches + the target acceleration by adjusting the offsets. + """ + + def sample_mask( + self, + shape: Sequence[int], + offset: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Sample a new k-space mask. + + This function samples and returns two components of a k-space mask: 1) + the center mask (e.g., for sensitivity map calculation) and 2) the + acceleration mask (for the edge of k-space). Both of these masks, as + well as the integer of low frequency samples, are returned. 
+ + Args: + shape: Shape of the k-space to subsample. + offset: Offset from 0 to begin mask (for equispaced masks). + + Returns: + A 3-tuple contaiing 1) the mask for the center of k-space, 2) the + mask for the high frequencies of k-space, and 3) the integer count + of low frequency samples. + """ + num_cols = shape[-2] + fraction_low_freqs, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_frequencies = round(num_cols * fraction_low_freqs) + + # bound the number of low frequencies between 1 and target columns + target_columns_to_sample = round(num_cols / acceleration) + num_low_frequencies = max(min(num_low_frequencies, target_columns_to_sample), 1) + + # adjust acceleration rate based on target acceleration. + adjusted_target_columns_to_sample = ( + target_columns_to_sample - num_low_frequencies + ) + adjusted_acceleration = 0 + if adjusted_target_columns_to_sample > 0: + adjusted_acceleration = round(num_cols / adjusted_target_columns_to_sample) + + center_mask = self.reshape_mask( + self.calculate_center_mask(shape, num_low_frequencies), shape + ) + accel_mask = self.reshape_mask( + self.calculate_acceleration_mask( + num_cols, adjusted_acceleration, offset, num_low_frequencies + ), + shape, + ) + + return center_mask, accel_mask, num_low_frequencies + + +def create_mask_for_mask_type( + mask_type_str: str, + center_fractions: Sequence[float], + accelerations: Sequence[int], +) -> MaskFunc: + """ + Creates a mask of the specified type. + + Args: + center_fractions: What fraction of the center of k-space to include. + accelerations: What accelerations to apply. + + Returns: + A mask func for the target mask type. + """ + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquiSpacedMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced_fraction": + return EquispacedMaskFractionFunc(center_fractions, accelerations) + elif mask_type_str == "magic": + return MagicMaskFunc(center_fractions, accelerations) + elif mask_type_str == "magic_fraction": + return MagicMaskFractionFunc(center_fractions, accelerations) + else: + raise ValueError(f"{mask_type_str} not supported") diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/original.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/original.png new file mode 100644 index 0000000000000000000000000000000000000000..8e9661372201bbd5c809478e5060f7c89c408c69 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/original.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..61c1d7d7c6ecbb701b85747cd6d2faf1e2fde9b6 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_fft.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary 
files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_fft.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_ks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..61c1d7d7c6ecbb701b85747cd6d2faf1e2fde9b6 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_ks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_masked_img.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..1b74feb3ab2cf078b641e9369acc31db5dbd70e6 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_masked_img.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/0_n_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_masked_img.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..000a9816d664f7a6282d9da2ba81aabdd652103e Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_masked_img.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..d04fa7b7e31f5d43205e7b7f9d50afbd02a51c54 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_fft.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_fft.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_ks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..d04fa7b7e31f5d43205e7b7f9d50afbd02a51c54 Binary files /dev/null and 
b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_ks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_masked_img.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..1f1e68643d94e2570045c15aced74c8f2e8c7823 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_masked_img.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/1_n_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_masked_img.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..3a323d0ddddc771fc6508161f93b919857c6eedb Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_masked_img.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..227825b2b4bf23e8d456e9b292af155e51f0c76f Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_fft.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_fft.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_ks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..227825b2b4bf23e8d456e9b292af155e51f0c76f Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_ks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_masked_img.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..2ff77d3c4df7dee5be4edc909c3980362b614f1d Binary files /dev/null and 
b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_masked_img.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/2_n_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_masked_img.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..a3d8492f7bbf1bb93c25f0e616f4000240f7fa5e Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_masked_img.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..3116c70ad5a72b9164c3da2f6fa2071e646372a2 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_fft.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_fft.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_ks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..3116c70ad5a72b9164c3da2f6fa2071e646372a2 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_ks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_masked_img.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..56c1877dc19f8ef25f99ffe435ba41940ac7f394 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_masked_img.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and 
b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/3_n_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_masked_img.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..fb66bd61130361dc3ed8deb95ad6fb030135efd4 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_masked_img.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..7b6952955f0a71f402b4d1bd0851324cd79194ae Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_fft.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_fft.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_fft.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_ks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_ks.png new file mode 100644 index 0000000000000000000000000000000000000000..7b6952955f0a71f402b4d1bd0851324cd79194ae Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_ks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_masked_img.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_masked_img.png new file mode 100644 index 0000000000000000000000000000000000000000..28d1b93788c84e2579d53425432e920fb298c190 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_masked_img.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..1fd5c724fc8f12d41758f085aaabe1220e7d3e40 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/progressive/4_n_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_images.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_images.png new file mode 100644 index 0000000000000000000000000000000000000000..5547693cce0491b530cf67b50bf7bd91d5bd91c6 Binary files /dev/null and 
b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_images.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_masks.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_masks.png new file mode 100644 index 0000000000000000000000000000000000000000..0db686e17102eb660c650a4081ef841b374741cc Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_masks.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_noisy_image.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_noisy_image.png new file mode 100644 index 0000000000000000000000000000000000000000..afedd4450bab040a01ffaa6af5af63617b2e0747 Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_noisy_image.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_noisy_mask.png b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_noisy_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..ff7c980e2daf9395368f8a73b44d3bb2f8efdc9a Binary files /dev/null and b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/degradation/visualization/sample_noisy_mask.png differ diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/diffusion_gaussian.py b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/diffusion_gaussian.py new file mode 100644 index 0000000000000000000000000000000000000000..9560f60b2545cb40f4f903e884de42f0a014c8ce --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/diffusion_pytorch/diffusion_gaussian.py @@ -0,0 +1,2045 @@ +import copy, time +import gc + +import torch +from torch import nn +import torch.nn.functional as func +from inspect import isfunction +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import AdamW, lr_scheduler +from torchvision import utils +import torch.nn.functional as F +# from einops import rearrange + +import os +import errno +from PIL import Image +# from pytorch_msssim import ssim +import cv2 +import numpy as np +import imageio +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +# from torch.utils.tensorboard import SummaryWriter +from diffusion_pytorch.degradation.k_degradation import get_fade_kernels, get_ksu_kernel, apply_ksu_kernel, apply_tofre, apply_to_spatial +from dataset import Dataset, Dataset_Aug1, BrainDataset + +# from skimage.metrics import structural_similarity as ssim +from skimage.metrics import peak_signal_noise_ratio as psnr + +from torchmetrics.image import StructuralSimilarityIndexMeasure + +from metrics.lpips import LPIPS +from metrics.fid import calculate_fid +from metrics.fid_3d import calculate_fid_3d +from metrics.nmse import nmse +from diffusion_pytorch.degradation.k_degradation import get_noisy_patches +import torch.amp as amp +from torch.cuda.amp import GradScaler, autocast +from metrics.frequency_loss import AMPLoss +scaler = GradScaler() + +# try: +# from apex import amp +# +# APEX_AVAILABLE = True +# except: +# APEX_AVAILABLE = False + + +def exists(x): + return x is not None + + +def 
default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def create_folder(path): + try: + os.mkdir(path) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + pass + + +def cycle(dl): + while True: + for inputs in dl: + yield inputs + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def loss_backwards(fp16, loss, optimizer, **kwargs): + if fp16: + scaler.scale(loss).backward(**kwargs) + else: + loss.backward(**kwargs) + + + +class EMA: + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + diffusion_type, + restore_fn, + *, + image_size, + device_of_kernel, + channels=3, + timesteps=1000, + loss_type='l1', + kernel_std=0.1, + initial_mask=11, + fade_routine='Incremental', + sampling_routine='default', + discrete=False, + accelerate_factor=4, + fp16=False, + normalizer="mean_std", + example_frequency_img=None + + ): + super().__init__() + self.fp16 = fp16 + self.channels = channels + self.image_size = image_size + self.restore_fn = restore_fn + self.accelerate_factor = accelerate_factor + + # self.backbone = diffusion_type.split('_')[0] + self.example_frequency_img = example_frequency_img + self.device_of_kernel = device_of_kernel + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + self.kernel_std = kernel_std + self.initial_mask = initial_mask + self.fade_routine = fade_routine + self.backbone = diffusion_type.split('_')[0] + self.degradation_type = diffusion_type.split('_')[1] + + self.sampling_routine = sampling_routine + self.discrete = discrete + + # Frequency Loss + if self.backbone == 'twobranch' or self.backbone == 'twounet': + self.amploss = AMPLoss() # .to(self.device, non_blocking=True) + self.kl_loss = torch.nn.KLDivLoss(reduction='sum') # sum , log_target=True + + self.lpips = LPIPS().eval().cuda() # .to(self.device, non_blocking=True) + self.ssim = StructuralSimilarityIndexMeasure() + + self.use_fre_loss = True + + self.use_ssim = False + self.use_lpips = True # 141 -> 144s + + self.clamp_every_sample = False # Stride + if normalizer == "min_max": + self.clamp_every_sample = True + + self.use_fre_noise = True + + self.update_kernel = False + self.use_patch_kernel = False + self.use_kl = False + + + if self.degradation_type == 'fade': + self.fade_kernels = get_fade_kernels(fade_routine, self.num_timesteps, image_size, kernel_std, initial_mask) + # print("=== self.fade_kernels shape = ", self.fade_kernels.shape) # [5, 256, 256] + + elif self.degradation_type == "kspace": + self.get_new_kspace(is_training=True) + # print("=== self.kspace_kernels shape = ", self.kspace_kernels.shape) # [5, 256, 256] + else: + raise NotImplementedError() + + def get_new_kspace(self, is_training=False): + # LinearSamplingRate, LogSamplingRate + self.kspace_kernels, self.noisy_kspace_kernels = get_ksu_kernel(self.num_timesteps, self.image_size, + ksu_routine="LogSamplingRate", is_training=is_training, + 
accelerated_factor=self.accelerate_factor, + example_frequency_img=self.example_frequency_img) + + self.kspace_kernels = torch.stack(self.kspace_kernels).squeeze(1).cuda() + self.noisy_kspace_kernels =self.kspace_kernels.cuda() + + + def get_kspace_kernels(self, index): + k = torch.stack([self.kspace_kernels[index]], 0).unsqueeze(0) + return k + + @torch.no_grad() + def sample(self, batch_size=16, faded_recon_sample=None, aux=None, + t=None, params_dict=None, sample_routine=None): + # Test + self.restore_fn.eval() + + if not sample_routine: + sample_routine = self.sampling_routine + + sample_device = faded_recon_sample.device + batch_size = faded_recon_sample.size(0) + + + + if t is None: + t = self.num_timesteps + + # print("self.kspace_kernels = ", self.kspace_kernels.shape) # self.kspace_kernels = torch.Size([5, 320, 320]) + # print("faded_recon_sample = ", faded_recon_sample.shape) # faded_recon_sample = torch.Size([16, 3, 320, 320]) + + # for i in range(t): + with torch.no_grad(): + k = torch.stack([self.kspace_kernels[[t - 1]]], 1) + # print("k = ", k.shape) # k = torch.Size([1, 1, 320, 320]) + faded_recon_sample = apply_ksu_kernel(faded_recon_sample, k, params_dict) + + return_k = k.repeat( batch_size, 1, 1, 1) + + xt = faded_recon_sample + # print("faded_recon_sample = ", faded_recon_sample.shape) + + + direct_recons = None + recon_sample = None + all_recons = [] + all_recons_fre = [] + all_masks = [] + + k_known_mask = torch.zeros_like(self.get_kspace_kernels(-1)).cuda() + + while t: + step = torch.full((batch_size,), t - 1, dtype=torch.long).cuda() + if self.backbone == "unet": + recon_sample = self.restore_fn(faded_recon_sample, step) + + elif self.backbone == "twounet": + recon_sample = self.restore_fn(faded_recon_sample, aux, k, step) + + elif self.backbone == "twobranch": + recon_sample, recon_fre = self.restore_fn(faded_recon_sample, aux, step) + all_recons_fre.append(recon_fre) + + if direct_recons is None: + direct_recons = recon_sample + all_recons.append(recon_sample) + + if self.degradation_type == 'kspace': + # faded_recon_sample = recon_sample + + if sample_routine == 'default': + all_recons.append(recon_sample) + with torch.no_grad(): + if t >=1: + k = self.get_kspace_kernels(t - 1) + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + + all_masks.append(k) + faded_recon_sample = recon_sample + + + elif sample_routine == 'x0_step_down': + all_recons.append(recon_sample) + if t <= 1: + if t == 1: + # recon_sample_sub_1 = recon_sample + # k = self.get_kspace_kernels(0, self.kspace_kernels) + # + # recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + faded_recon_sample = recon_sample #faded_recon_sample - recon_sample + recon_sample_sub_1 + + else: + faded_recon_sample = recon_sample + all_masks.append(k) + else: + with torch.no_grad(): + k = self.get_kspace_kernels(t - 2, self.kspace_kernels) + recon_sample_sub_1 = apply_ksu_kernel(recon_sample, k, params_dict) + + k = self.get_kspace_kernels(t - 1, self.kspace_kernels) + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + + faded_recon_sample = faded_recon_sample - recon_sample + recon_sample_sub_1 + + if self.clamp_every_sample: + faded_recon_sample = faded_recon_sample.clamp(-1, 1) + all_masks.append(k) + + elif sample_routine == 'x0_step_down_fre': + all_recons.append(recon_sample) + if t <= 1: + kt = self.get_kspace_kernels(0).cuda() # last one + faded_recon_sample = recon_sample + k_residual = torch.ones_like(kt).cuda() + + + else: + k_full = self.get_kspace_kernels(- 1) + 
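+ # x0_step_down_fre: transform the running estimate to k-space, splice in the newly
+ # revealed frequency band (the difference between consecutive sampling kernels) taken
+ # from the network's reconstruction, re-apply the denser kernel, and map back to the
+ # image domain.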
faded_recon_sample_fre, k_full = apply_tofre(faded_recon_sample, k_full, params_dict) + # print('k_full = ', k_full.shape) + + with torch.no_grad(): + + kt_sub_1 = self.get_kspace_kernels(t - 1).cuda() + kt = self.get_kspace_kernels(t - 0).cuda() # last one + k_residual = kt_sub_1 - kt + recon_sample_fre, k_residual = apply_tofre(recon_sample, k_residual, params_dict) + + fre_amend = recon_sample_fre * k_residual + faded_recon_sample_fre = faded_recon_sample_fre + fre_amend # * (1-k_residual) + + faded_recon_sample_fre = faded_recon_sample_fre * kt_sub_1 + + faded_recon_sample = apply_to_spatial(faded_recon_sample_fre, params_dict) + + k_known_mask += k_residual #.cpu() + all_masks.append(k_known_mask.cpu().clone()) + + if self.clamp_every_sample: + faded_recon_sample = faded_recon_sample.clamp(-1, 1) + + + elif sample_routine == 'fre_progressive': + all_recons.append(recon_sample) + if t == 1: + recon_sample_sub_1 = recon_sample + k = self.get_kspace_kernels(0, self.kspace_kernels) + + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + faded_recon_sample = faded_recon_sample - recon_sample + recon_sample_sub_1 + + elif t == 0: + faded_recon_sample = recon_sample + all_masks.append(k) + else: + k_full = self.get_kspace_kernels(- 1, self.kspace_kernels) + faded_recon_sample_fre, k_full = apply_tofre(faded_recon_sample, k_full, params_dict) + + with torch.no_grad(): + + kt_sub_1 = self.get_kspace_kernels(t - 2, self.kspace_kernels).cuda() + kt = self.get_kspace_kernels(t - 1, self.kspace_kernels).cuda() # last one + k_residual = kt_sub_1 - k_full + recon_sample_fre, k_residual = apply_tofre(recon_sample, k_residual, params_dict) + + + fre_amend = recon_sample_fre * k_residual # new + faded_recon_sample_fre = faded_recon_sample_fre * k_full + fre_amend # * (1-k_residual) + + # faded_recon_sample_fre = faded_recon_sample_fre * kt_sub_1 # + recon_sample * (1 - kt_sub_1) + + faded_recon_sample = apply_to_spatial(faded_recon_sample_fre, params_dict) + + k_known_mask += k_residual # .cpu() + all_masks.append(k_known_mask.cpu().clone()) + + if self.clamp_every_sample: + faded_recon_sample = faded_recon_sample.clamp(-1, 1) + + + + recon_sample = faded_recon_sample + # print("recon_sample = ", recon_sample.shape) + + t -= 1 + + all_recons = torch.stack(all_recons) + all_masks = torch.stack(all_masks) + all_recons_fre = torch.stack(all_recons_fre) + + return xt, direct_recons, recon_sample, return_k, all_recons, all_recons_fre, all_masks + + @torch.no_grad() + def all_sample(self, batch_size=16, faded_recon_sample=None, aux=None, t=None, params_dict=None, times=None): + # TODO + print("Running into all_sample...") + rand_kernels = None + sample_device = faded_recon_sample.device + if self.degradation_type == 'fade': + if 'Random' in self.fade_routine: + rand_kernels = [] + rand_x = torch.randint(0, self.image_size + 1, (batch_size,), device=faded_recon_sample.device).long() + rand_y = torch.randint(0, self.image_size + 1, (batch_size,), device=faded_recon_sample.device).long() + for i in range(batch_size, ): + rand_kernels.append(torch.stack( + [self.fade_kernels[j][rand_x[i]:rand_x[i] + self.image_size, + rand_y[i]:rand_y[i] + self.image_size] for j in range(len(self.fade_kernels))])) + rand_kernels = torch.stack(rand_kernels) + + elif self.degradation_type == 'kspace': + rand_kernels = [] + rand_x = torch.randint(0, self.image_size + 1, (batch_size,), device=faded_recon_sample.device).long() + + for i in range(batch_size, ): + rand_kernels.append(torch.stack( + 
[self.fade_kernels[j][rand_x[i]:rand_x[i] + self.image_size, + : self.image_size] for j in range(len(self.fade_kernels))])) + rand_kernels = torch.stack(rand_kernels) + + if t is None: + t = self.num_timesteps + if times is None: + times = t + + for i in range(t): + with torch.no_grad(): + if self.degradation_type == 'fade': + if 'Random' in self.fade_routine: + faded_recon_sample = torch.stack([rand_kernels[:, i].to(sample_device), + rand_kernels[:, i].to(sample_device), + rand_kernels[:, i].to(sample_device)], 1) * faded_recon_sample + else: + faded_recon_sample = self.fade_kernels[i].to(sample_device) * faded_recon_sample + elif self.degradation_type == 'kspace': + if rand_kernels is not None: + # print(f"kspace randkeynel k={rand_kernels[:, i].shape}, x={x.shape}") + k = torch.stack([rand_kernels[:, i]], 1) + faded_recon_sample = apply_ksu_kernel(faded_recon_sample, k, params_dict) + else: + # print(f"kspace k={self.kspace_kernels[i].shape}, x={x.shape}") + k = self.kspace_kernels[i] + faded_recon_sample = apply_ksu_kernel(faded_recon_sample, k, params_dict) + + if self.discrete: + faded_recon_sample = (faded_recon_sample + 1) * 0.5 + faded_recon_sample = (faded_recon_sample * 255) + faded_recon_sample = faded_recon_sample.int().float() / 255 + faded_recon_sample = faded_recon_sample * 2 - 1 + + x0_list = [] + xt_list = [] + + while times: + step = torch.full((batch_size,), times - 1, dtype=torch.long).cuda() + if self.backbone == "unet": + recon_sample = self.restore_fn(faded_recon_sample, step) + elif self.backbone == "twounet": + recon_sample = self.restore_fn(faded_recon_sample, aux, k, step) + + elif self.backbone == "twobranch": + recon_sample, recon_fre = self.restore_fn(faded_recon_sample, aux, step) + recon_sample = recon_sample #// 2 + recon_fre // 2 + + x0_list.append(recon_sample) + + if self.degradation_type == 'fade': + if self.sampling_routine == 'default': + for i in range(times - 1): + with torch.no_grad(): + if rand_kernels is not None: + recon_sample = torch.stack([rand_kernels[:, i].to(sample_device), + rand_kernels[:, i].to(sample_device), + rand_kernels[:, i].to(sample_device)], 1) * recon_sample + else: + recon_sample = self.fade_kernels[i].to(sample_device) * recon_sample + faded_recon_sample = recon_sample + + elif self.sampling_routine == 'x0_step_down': + for i in range(t): + with torch.no_grad(): + recon_sample_sub_1 = recon_sample + if rand_kernels is not None: + + recon_sample = apply_ksu_kernel(recon_sample, rand_kernels[i], params_dict) + else: + recon_sample = apply_ksu_kernel(recon_sample, self.kspace_kernels[i], params_dict) + + faded_recon_sample = faded_recon_sample - recon_sample + recon_sample_sub_1 + + elif self.degradation_type == 'kspace': + # faded_recon_sample = recon_sample + if self.sampling_routine == 'default': + for i in range(t - 1): + with torch.no_grad(): + if rand_kernels is not None: + k = torch.stack([rand_kernels[:, i]], 1) + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + else: + recon_sample = apply_ksu_kernel(recon_sample, self.kspace_kernels[i], params_dict) + + faded_recon_sample = recon_sample + + elif self.sampling_routine == 'x0_step_down': + for i in range(t): + with torch.no_grad(): + recon_sample_sub_1 = recon_sample + if rand_kernels is not None: + k = torch.stack([rand_kernels[:, i]], 1) + recon_sample = apply_ksu_kernel(recon_sample, k, params_dict) + else: + recon_sample = apply_ksu_kernel(recon_sample, self.kspace_kernels[i], params_dict) + + faded_recon_sample = faded_recon_sample - recon_sample + 
recon_sample_sub_1 + + + xt_list.append(faded_recon_sample) + times -= 1 + + return x0_list, xt_list + + # Train + def q_sample(self, x_start, t, params_dict=None, use_fre_noise=False): + x = x_start + + with torch.no_grad(): + k = torch.stack([self.kspace_kernels[t]], 1) + x = apply_ksu_kernel(x, k, params_dict, use_fre_noise=use_fre_noise) # self.use_fre_noise + + return x, k + + + def reconstruct_loss(self, x_start, x_recon): + if self.loss_type == 'l1': + loss = (x_start - x_recon).abs().mean() + elif self.loss_type == 'l2': + loss = func.mse_loss(x_start, x_recon) + else: + raise NotImplementedError() + return loss + + + def gaussian_kernel(self, size: int, sigma: float): + """Generates a 2D Gaussian kernel.""" + x = torch.arange(size).float() - size // 2 + gauss = torch.exp(-x ** 2 / (2 * sigma ** 2)) + kernel = gauss[:, None] @ gauss[None, :] + kernel /= kernel.sum() + return kernel.cuda() + + def gaussian_blur(self, input_tensor, kernel_size: int, sigma: float): + """Applies Gaussian blur to a 4D tensor (N, C, H, W).""" + # Create Gaussian kernel + kernel = self.gaussian_kernel(kernel_size, sigma).unsqueeze(0).unsqueeze(0) + kernel = kernel.expand(input_tensor.size(1), 1, kernel_size, kernel_size) # For each channel + + # Pad the input tensor to avoid size reduction + padding = kernel_size // 2 + input_tensor = F.pad(input_tensor, (padding, padding, padding, padding), mode='reflect') + + # Apply convolution + blurred = F.conv2d(input_tensor, kernel, groups=input_tensor.size(1)) + return blurred + + def get_frequency_elements(self, x): + x_fft = torch.fft.rfft2(x, norm="ortho") + + # Perform FFT and compute magnitudes + x_mag = torch.clamp(torch.abs(x_fft), min=1e-8) + x_phase = torch.angle(x_fft) + + return x_mag, x_phase + + def get_fre_kl_loss(self, pred_spa, pred_fre, target, k): + # Flatten the elements + B = pred_spa.shape[0] + k = k.contiguous().view(B, -1) + pred_spa = pred_spa.view(B, -1) + pred_fre = pred_fre.view(B, -1) + target = target.view(B, -1) + + # minus max value + pred_spa = pred_spa - torch.max(pred_spa, dim=1, keepdim=True).values + pred_fre = pred_fre - torch.max(pred_fre, dim=1, keepdim=True).values + target = target - torch.max(target, dim=1, keepdim=True).values + + target_prob = F.softmax(target, dim=1) + + ele_num = 2 + target_all_prob = torch.cat([target_prob for ii in range(ele_num)], dim=0) + k_mask = torch.cat([k for ii in range(ele_num)], dim=0) + k_total = torch.sum(k_mask) + + pred_all = torch.cat([pred_spa, pred_fre]) # 4B + pred_all = F.log_softmax(pred_all, dim=1) + + + # consistency_loss + # get probability + pred_spa_prob = F.softmax(pred_spa, dim=1) + pred_fre_prob = F.softmax(pred_fre, dim=1) + pred_avg_prob = 1.0 / ele_num * (pred_spa_prob + pred_fre_prob) # 2 B + pred_avg_prob = torch.cat([pred_avg_prob for ii in range(ele_num)], dim=0).clone().detach() + + kl_consist_loss = (self.kl_loss(pred_all, pred_avg_prob) * k_mask).sum() / k_total + + return kl_consist_loss # + kl_loss + + + def frequency_consistency_loss(self, pred_spa, pred_fre, target, k): + ''' + KL-term, enforcing conditional distribution remains unchanged regardless of interventions applied + ''' + + W = pred_spa.shape[-1] + half_W = W // 2 + 1 + k = (1 - k.to(pred_spa.device)) # negative mask + k = k[..., :half_W] + + pred_spa_mag, pred_spa_pha = self.get_frequency_elements(pred_spa) + pred_fre_mag, pred_fre_pha = self.get_frequency_elements(pred_fre) + target_mag, target_pha = self.get_frequency_elements(target) + + mag_kl_loss = self.get_fre_kl_loss(pred_spa_mag, 
pred_fre_mag, target_mag, k) + pha_kl_loss = self.get_fre_kl_loss(pred_spa_pha, pred_fre_pha, target_pha, k) + + # print("mag loss=", mag_kl_loss, "pha loss=", pha_kl_loss) + return mag_kl_loss + pha_kl_loss + + + def p_losses(self, x_start, aux, t, params_dict): + self.debug_print = False + self.debug_time = False + + start_time = time.time() + + x_start_golden = x_start.clone() + x_mix, k = self.q_sample(x_start=x_start, t=t, params_dict=params_dict) + + # gaussian blur for x_mix + # if np.random.rand() > 0.5: + # x_mix = self.gaussian_blur( + # x_mix, + # kernel_size=int(torch.randint(1, 9, (1,)).item() * 2 + 1), # Ensure odd kernel size + # sigma=torch.abs(torch.randn(1) * 3.0).item() # Ensure sigma is positive + # ) + + # Add gaussian noise + # sigma = 0.1 * torch.abs(torch.rand(1)).item() # Standard Deviation + # x_mix = x_mix + torch.randn_like(x_mix) * sigma + # aux = aux + torch.randn_like(x_mix) * sigma + + x_mix = x_mix.detach() + aux = aux.detach() + x_start_golden = x_start_golden.detach() + k = k.detach() + + if self.debug_time: + print("--------------------") + print("sample time=", time.time() - start_time) # 0.02s ~ 0.03s + + + if self.backbone == 'unet': + x_recon = self.restore_fn(x_mix, t) + loss = self.reconstruct_loss(x_start_golden, x_recon) + + if self.use_lpips: + lpips_weight = 0.1 + lpips_loss = lpips_weight * self.lpips(x_recon, x_start_golden).mean() + loss += lpips_loss + + + if self.use_fre_loss: # NAN + fft_weight = 0.01 + + fre_loss = fft_weight * self.amploss(x_recon, x_start_golden, k) + loss += fre_loss + + + elif self.backbone == 'twounet': + x_recon = self.restore_fn(x_mix, aux, k, t) + loss = self.reconstruct_loss(x_start_golden, x_recon) * 5.0 + + # LPIPS + if self.use_lpips: + lpips_weight = 0.1 + lpips_loss = self.lpips(x_recon, x_start_golden).mean() + loss += lpips_weight * lpips_loss + + + if self.use_fre_loss: # NAN + fft_weight = 0.1 + amp = self.amploss(x_recon, x_start_golden, k) + loss += fft_weight * amp + + + elif self.backbone == 'twobranch': + + if self.fp16: + with autocast(): + x_recon, x_recon_fre = self.restore_fn(x_mix, aux, t) + else: + x_recon, x_recon_fre = self.restore_fn(x_mix, aux, t) + if self.debug_time: + print("restore_fn time=", time.time() - start_time) + + # img_mean = params_dict['img_mean'].cuda().view(-1, 1, 1, 1) + # img_std = params_dict['img_std'].cuda().view(-1, 1, 1, 1) + # x_start_golden = x_start_golden * img_std + img_mean # 0 - 1 + # x_recon = x_recon * img_std + img_mean + # x_recon_fre = x_recon_fre * img_std + img_mean + + loss_spatial = self.reconstruct_loss(x_start_golden, x_recon) + loss_freq = self.reconstruct_loss(x_start_golden, x_recon_fre) + loss = loss_spatial + loss_freq + + if self.debug_time: + print("reconstruct_loss time=", time.time() - start_time) + + # LPIPS + if self.use_lpips: + lpips_weight = 0.1 + lpips_loss = lpips_weight * self.lpips(x_recon, x_start_golden).mean() + loss += lpips_loss + + if self.use_ssim: + ssim_weight = 0.1 + ssim_loss = 1.0 - self.ssim(x_recon, x_start_golden).mean() + loss += ssim_weight * ssim_loss + + if self.use_fre_loss: # NAN + fft_weight = 0.01 + + fre_loss = fft_weight * self.amploss(x_recon_fre, x_start_golden, k) + loss += fre_loss + + # fre_loss = fft_weight * self.amploss(x_recon, x_start_golden, k) + # loss += fre_loss + + if self.use_kl: + amp_fre = fft_weight * self.frequency_consistency_loss(x_recon, x_recon_fre, x_start_golden, k) + loss += amp_fre + + if self.debug_time: + print("fre loss time=", time.time() - start_time) + 
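+ # NOTE: the accumulated two-branch objective combines the spatial and frequency-branch
+ # reconstruction losses with the optional LPIPS, SSIM, amplitude-spectrum and KL
+ # frequency-consistency terms weighted above.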
print("--------------------") + + if np.random.rand() < 0.001: + print("----------------------------------------\n" + "loss_spatial:", loss_spatial.item(), + "loss_freq:", loss_freq.item(), + # "lpips_loss:", lpips_loss.item(), + # "ssim_loss", ssim_loss.item(), + "fre_loss:", fre_loss.item()) + + print("----------------------------------------\n" + "x_recon:", x_recon.min().item(), x_recon.max().item(), + "x_recon_fre:", x_recon_fre.min().item(), x_recon_fre.max().item(), + "x_start_golden:", x_start_golden.min().item(), x_start_golden.max().item()) + + return loss + + + def forward(self, x1, x2=None, params_dict=None, *args, **kwargs): + b, c, h, w, device, img_size, = *x1.shape, x1.device, self.image_size + assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + + loss = self.p_losses(x1, x2, t, params_dict, *args, **kwargs) + return loss + + + +class Trainer(nn.Module): + def __init__( + self, + diffusion_model, + folder, + mode, + *, + norm="mean_std", + ema_decay=0.995, + image_size=128, + train_batch_size=32, + train_lr=2e-5, + train_num_steps=700000, + gradient_accumulate_every=2, + fp16=False, + step_start_ema=2000, + update_ema_every=100, + save_and_sample_every=1000, + results_folder='./results', + load_path=None, + dataset=None, + shuffle=True, + domain=None, + aux_modality=None, + num_channels=1, + debug=False, + ): + super().__init__() + + self.mode = mode + self.model = diffusion_model + self.ema = EMA(ema_decay) + self.ema_model = copy.deepcopy(self.model) + self.update_ema_every = update_ema_every + + self.step_start_ema = step_start_ema + self.save_and_sample_every = save_and_sample_every if not debug else 10 + + self.batch_size = train_batch_size + self.image_size = diffusion_model.module.image_size + self.gradient_accumulate_every = gradient_accumulate_every + self.train_num_steps = train_num_steps + self.input_normalize = norm + + + if dataset == 'train': + print(dataset, "DA used") + self.ds = Dataset_Aug1(folder, image_size) + + elif dataset.lower() == 'brain': + print(dataset, "Brain DA used", "mode=", mode) + # mode, base_dir, image_size, nclass, domains, aux_modality, + self.ds = BrainDataset(mode, folder, image_size, 4, + debug=debug, + domains=domain, + num_channels=num_channels, + aux_modality=aux_modality) # mode, base_dir, domains: + + + + elif dataset.lower() == 'fsm_brain': + from dataset.BRATS_dataloader import Hybrid, ToTensor, RandomPadCrop, AddNoise + from torchvision import transforms + train_data_path = folder + + self.ds = Hybrid(split=mode, SNR=0, + # transform=transforms.Compose([RandomPadCrop(), ToTensor(), AddNoise()]), + transform=transforms.Compose([RandomPadCrop(), ToTensor()]), + base_dir=train_data_path, input_normalize=norm, + image_size=image_size, debug=debug) + + + self.test_ds = Hybrid(split="test", SNR=0, + # transform=transforms.Compose([RandomPadCrop(), ToTensor(), AddNoise()]), + transform=transforms.Compose([RandomPadCrop(), ToTensor()]), + base_dir=train_data_path, input_normalize=norm, + image_size=image_size, debug=debug) + + elif dataset.lower() == 'fsm_knee': + from dataset.fastmri import SliceDataset + root_path = folder + transforms = None + + # self.ds = build_dataset(args, mode='train') + self.ds = SliceDataset(root_path, transforms, 'singlecoil', + image_size=image_size, input_normalize=norm, + sample_rate=1, mode=mode, debug=debug) + + self.test_ds = SliceDataset(root_path, transforms, 'singlecoil', + 
image_size=image_size, input_normalize=norm, + sample_rate=1, mode="test", debug=debug) + + elif dataset.lower() == 'fsm_m4raw': + pass + + + + else: + print(dataset) + self.ds = Dataset(folder, image_size) + + self.train_batch_size = train_batch_size + self.batch_size = train_batch_size if mode == 'train' else 1 + + self.dl = cycle( + data.DataLoader(self.ds, + batch_size=train_batch_size if mode == 'train' else 1, + shuffle=(mode == "train"), + pin_memory=True, + num_workers=16, + drop_last=True)) + + self.test_dl = cycle( + data.DataLoader(self.ds, + batch_size=train_batch_size, + shuffle="test", + pin_memory=True, + num_workers=16, + drop_last=False)) + + self.opt = AdamW(list(self.model.module.restore_fn.parameters()), + lr=train_lr, + betas=(0.9, 0.999), + weight_decay=1e-4) + + self.scheduler = lr_scheduler.StepLR(self.opt, step_size=10000, gamma=0.5) + self.step = 0 + + # assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex must be installed for mixed precision training on' + + self.fp16 = fp16 + os.makedirs(results_folder, exist_ok=True) + self.results_folder = Path(results_folder) + + np.save(str(self.results_folder / "kspace_kernels.npy"), self.model.module.kspace_kernels.cpu()) + + self.lpips = LPIPS().eval().cuda() + + self.reset_parameters() + + if load_path is not None: + self.load(load_path) + kspace_npy = load_path.replace('model.pt', 'kspace_kernels.npy') + self.model.module.kspace_kernels = torch.from_numpy(np.load(kspace_npy)).to(self.model.module.kspace_kernels.device) + self.ema_model.module.kspace_kernels = self.model.module.kspace_kernels + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + def step_ema(self): + if self.step < self.step_start_ema: + self.reset_parameters() + return + self.ema.update_model_average(self.ema_model, self.model) + + def save(self): + model_data = { + 'step': self.step, + 'model': self.model.state_dict(), + 'ema': self.ema_model.state_dict() + } + save_name = str(self.results_folder / f'model.pt') + print("Save_name=", save_name) + torch.save(model_data, save_name) + + def load(self, load_path): + print("Loading : ", load_path) + model_data = torch.load(load_path) + + self.step = model_data['step'] + self.model.load_state_dict(model_data['model']) + self.ema_model.load_state_dict(model_data['ema']) + print("Loading complete") + + @staticmethod + def add_title(path, title): + img1 = cv2.imread(path) + + black = [0, 0, 0] + constant = cv2.copyMakeBorder(img1, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + height = 20 + violet = np.zeros((height, constant.shape[1], 3), np.uint8) + violet[:] = (255, 0, 180) + + vcat = cv2.vconcat((violet, constant)) + + font = cv2.FONT_HERSHEY_SIMPLEX + + cv2.putText(vcat, str(title), (violet.shape[1] // 2, height - 2), font, 0.5, (0, 0, 0), 1, 0) + cv2.imwrite(path, vcat) + + # + def calculate_metrics(self, all_images, og_img): + img_ = all_images.cpu() #.permute(0, 2, 3, 1).numpy()[..., 0] + og_img_ = og_img.cpu() #.permute(0, 2, 3, 1).numpy()[..., 0] + + # print("img_=", img_.shape, "og_img_=", og_img_.shape) # img_= torch.Size([4, 1, 240, 240]) og_img_= torch.Size([4, 1, 240, 240]) + + # B, C, H, W + # ssim = StructuralSimilarityIndexMeasure(data_range=255) + ssims_ = [] + for (img, og_img) in zip(img_, og_img_): + img_np = img.squeeze().numpy() # Convert to 2D + og_img_np = og_img.squeeze().numpy() # Convert to 2D + + # Compute SSIM for each pair of images + ssim_ = structural_similarity(og_img_np, img_np) + ssims_.append(ssim_) + + ssim_ = np.mean(ssims_) + + 
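+ # NOTE: SSIM is computed slice-by-slice and averaged above, whereas PSNR and NMSE below
+ # are evaluated once over the stacked batch arrays.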
psnr_ = peak_signal_noise_ratio(og_img_.numpy(), img_.numpy()).mean() + nmse_ = nmse(og_img_.numpy(), img_.numpy()).mean() + + + return ssim_, psnr_, nmse_ + + # pip install pytorch-fid + + # (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + def calculate_metrics_3d(self, all_images, og_img): + all_images = torch.clamp(all_images, 1e-6, 1) + og_img = torch.clamp(og_img, 1e-6, 1) + + # img_ = torch.Size([5, 1, 64, 64]) torch.Size([5, 1, 64, 64] + all_images_new = all_images.unsqueeze(1).repeat(1, 3, 1, 1) + og_img_new = og_img.unsqueeze(1).repeat(1, 3, 1, 1) + + cal_fid = False + if cal_fid: + # (N, 3, 299, 299) + fid_value = calculate_fid(all_images_new.cpu().numpy(), og_img_new.cpu().numpy(), + use_multiprocessing=False, batch_size=og_img.shape[-1]) + # (N, 3, C, 256, 256) + fid_value_3d = calculate_fid_3d(all_images_new.cpu().numpy(), og_img_new.cpu().numpy(), + use_multiprocessing=False, batch_size=og_img.shape[-1]) + else: + fid_value = 0 + fid_value_3d = 0 + + # B, C, H, W + lpips = self.lpips(all_images_new.cuda(), og_img_new.cuda()).mean().item() + + # H, W, C + img_ = all_images.cpu().unsqueeze(0) # .permute(1, 2, 0).numpy() + og_img_ = og_img.cpu().unsqueeze(0) #.permute(1, 2, 0).numpy() #.numpy() + + # 0-1, H, W, C + ssim = StructuralSimilarityIndexMeasure(data_range=1.0) + ssim_ = ssim(og_img_, img_).mean() + psnr_ = psnr(og_img_.numpy(), img_.numpy(), data_range=1.0).mean() + + return ssim_, psnr_, fid_value, fid_value_3d, lpips + + # Evaluate all + def test_data_dict(self, tag, data_dict, batches, routine): + print("=== Test tag: ", tag) + + og_img = data_dict['img'].cuda() + aux = data_dict['aux'].cuda() + img_mean = data_dict['img_mean'].cuda() + img_std = data_dict['img_std'].cuda() + aux_mean = data_dict['aux_mean'].cuda() + aux_std = data_dict['aux_std'].cuda() + params_dict = {'img_mean': img_mean, 'img_std': img_std, 'aux_mean': aux_mean, 'aux_std': aux_std} + + + # print("input shape minmax, og_img:", og_img.min().item(), og_img.max().item(), + # "aux:", aux.min().item(), aux.max().item()) + + xt, direct_recons, all_images, return_k, all_recons, all_recons_fre, all_masks = ( + self.ema_model.module.sample( + batch_size=batches, + faded_recon_sample=og_img, + aux=aux, params_dict=params_dict, + sample_routine=routine)) + + img_std = img_std.view(-1, 1, 1, 1) + img_mean = img_mean.view(-1, 1, 1, 1) + aux_std = aux_std.view(-1, 1, 1, 1) + aux_mean = aux_mean.view(-1, 1, 1, 1) + + og_img = og_img * img_std + img_mean + + _min = og_img.min() + _max = og_img.max() + + og_img = (og_img - _min) / (_max - _min) + all_images = all_images * img_std + img_mean + all_images = (all_images - _min) / (_max - _min) + all_recons = all_recons * img_std + img_mean + all_recons = (all_recons - _min) / (_max - _min) + all_images = all_recons[-1] + + direct_recons = direct_recons * img_std + img_mean + direct_recons = (direct_recons - _min) / (_max - _min) + xt = xt * img_std + img_mean + xt = (xt - _min) / (_max - _min) + + aux = aux * aux_std + aux_mean + aux = (aux - aux.min()) / (aux.max() - aux.min()) + + # print("----------------------------------------\n" + # "all_recons:", all_recons.min().item(), all_recons.max().item(), + # "all_images:", all_images.min().item(), all_images.max().item(), + # "og_img:", og_img.min().item(), og_img.max().item(), + # "direct_recons:", direct_recons.min().item(), direct_recons.max().item()) + + all_recons = torch.clamp(all_recons, 1e-6, 1).mul(255).to(torch.int8) + direct_recons = torch.clamp(direct_recons, 1e-6, 1).mul(255).to(torch.int8) + 
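+ # NOTE: values above 127/255 overflow when cast to torch.int8; torch.uint8 is presumably
+ # the intended dtype for these 0-255 images (the same applies to the conversions below).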
all_images = torch.clamp(all_images, 1e-6, 1).mul(255).to(torch.int8) + og_img = torch.clamp(og_img, 1e-6, 1).mul(255).to(torch.int8) + + + + # 24, 1, 128, 128 + # Calculate SSIM and PSNR, LPIPS + ssims = [] + psnrs = [] + + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, direct_recons) + lpips = self.lpips(direct_recons.float()/255, og_img.float()/255).mean().item() + print(f"=== first step Metrics {routine}: SSIM: ", ssim_, " PSNR: ", psnr_, + " LPIPS: ", lpips, " NMSE: ", nmse_) + + for im in all_recons: + im = torch.clamp(im, 1e-6, 1).mul(255).to(torch.int8) + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, im) + ssims.append(ssim_) + psnrs.append(psnr_) + + + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, all_images) + lpips = self.lpips(all_images.float()/255, og_img.float()/255).mean().item() + + print(f"=== Final Metrics {routine}: SSIM: ", ssim_, " PSNR: ", psnr_, + " LPIPS: ", lpips, " NMSE: ", nmse_) + + os.makedirs(self.results_folder, exist_ok=True) + + # utils.save_image(xt, str(self.results_folder / f'{self.step}-xt-Noise.png'), nrow=6) + # utils.save_image(all_images, str(self.results_folder / f'{self.step}-full_recons.png'), + # nrow=6) + # utils.save_image(direct_recons, + # str(self.results_folder / f'{self.step}-sample-direct_recons.png'), nrow=6) + # utils.save_image(og_img, str(self.results_folder / f'{self.step}-img.png'), nrow=6) + # utils.save_image(aux, str(self.results_folder / f'{self.step}-aux.png'), nrow=6) + + # plot ssim and psnr in two sub plots same canvas parallel + import matplotlib.pyplot as plt + fig, axs = plt.subplots(2, figsize=(8, 6), dpi=100, sharex=True) + axs[0].plot(ssims, marker='o', linestyle='-', color='blue', label='SSIM') + axs[0].set_title('Structural Similarity Index (SSIM)', fontsize=14) + axs[0].set_ylabel('SSIM', fontsize=12) + axs[0].grid(True, linestyle='--', alpha=0.7) + axs[0].legend() + + # PSNR plot + axs[1].plot(psnrs, marker='o', linestyle='-', color='green', label='PSNR') + axs[1].set_title('Peak Signal-to-Noise Ratio (PSNR)', fontsize=14) + axs[1].set_xlabel('Iterations', fontsize=12) + axs[1].set_ylabel('PSNR (dB)', fontsize=12) + axs[1].grid(True, linestyle='--', alpha=0.7) + axs[1].legend() + + fig.tight_layout() + + plt.savefig(str(self.results_folder / f'{self.step}-metrics-{routine}.png')) + + return_k = return_k.cuda() + + + combine = torch.cat((return_k, + xt, + direct_recons, + all_images, + all_recons[-1].to(all_images.device), + og_img, aux), 2) + + utils.save_image(combine, str(self.results_folder / f'{self.step}-combine-{routine}.png'), nrow=6) + + # all_recon = all_recons[:, 0] # 50, 1, 128, 128 + # Ensure all_recons is on the CPU + + all_recons = torch.cat(list(all_recons), dim=-1).cpu() + all_masks = all_masks.cpu() + # all_masks = torch.cat(list(all_masks), dim=-1) + + s = all_recons.shape[-2] + repeats = all_recons.shape[3] // og_img.shape[3] # Calculate repeat factor + # tensor_small = tensor_small.repeat(1, 1, 1, repeats) + og_img = og_img.cpu() + all_recons_residual = all_recons - og_img.repeat(1, 1, 1, repeats) + all_recons_residual = (all_recons_residual - all_recons_residual.min()) / (all_recons_residual.max() - all_recons_residual.min()) + + # before and after residual + all_recons_residual_2 = all_recons[:, :, :, s:] - all_recons[:, :, :, :-s] + all_recons_residual_2 = (all_recons_residual_2 - all_recons_residual_2.min()) / (all_recons_residual_2.max() - all_recons_residual_2.min()) + padding = torch.zeros_like(all_recons_residual[:, :, :, :s // 2]) + all_recons_residual_2 = 
torch.cat([padding, all_recons_residual_2, padding], dim=-1) + + all_recons = torch.cat([all_recons, all_recons_residual_2, all_recons_residual], dim=-2) + + utils.save_image(all_recons, str(self.results_folder / f'{self.step}-all_recons-{routine}.png'), + nrow=1) + # utils.save_image(all_masks, str(self.results_folder / f'{self.step}-all_masks-{routine}.png'), + # nrow=1) + + # acc_loss = acc_loss / (self.save_and_sample_every + 1) + print(f'Mean of last {self.step}: save to :', str(self.results_folder / f'{self.step}-combine.png')) + + # acc_loss = 0 + + + def train(self): + backwards = partial(loss_backwards, self.fp16) + # writer = SummaryWriter() + + acc_loss = 0 + start_time = time.time() + + while self.step < self.train_num_steps: + d_time = time.time() + self.opt.zero_grad() + u_loss = 0 + + if (self.step + 1 )% 20000 == 0: + self.ds.update_chunk() + self.dl = cycle( + data.DataLoader(self.ds, + batch_size=self.train_batch_size, + shuffle=True, + pin_memory=True, + num_workers=16, + drop_last=True)) + self.test_dl = cycle( + data.DataLoader(self.test_ds, + batch_size=self.train_batch_size, + shuffle=False, + pin_memory=True, + num_workers=16, + drop_last=False)) + + self.debug_time = False + + for i in range(self.gradient_accumulate_every): + + last_model_state = self.model.state_dict() + optimizer_state = self.opt.state_dict() + + data_dict = next(self.dl) + if self.debug_time: + print("Data loading time=", time.time() - d_time) # Very slow, 0.5 s second on bask, 0.001 on local + + img = data_dict['img'].cuda() + aux = data_dict['aux'].cuda() + img_mean = data_dict['img_mean'].cuda() + img_std = data_dict['img_std'].cuda() + aux_mean = data_dict['aux_mean'].cuda() + aux_std = data_dict['aux_std'].cuda() + params_dict = {'img_mean': img_mean, 'img_std': img_std, 'aux_mean': aux_mean, 'aux_std': aux_std} + + + loss = torch.mean(self.model(img, aux, params_dict)) + if self.debug_time: + print("Model iter=", self.step, " time=", time.time() - d_time) + + if torch.isnan(loss).any(): + print(f"NaN encountered in step {self.step}. 
Reverting model.") + self.model.load_state_dict(last_model_state) # Revert model + self.opt.load_state_dict(optimizer_state) # Revert optimizer + continue # Skip the rest of this training step + if self.debug_time: + print("before loss=", time.time() - d_time) + u_loss += loss + loss.backward() + # backwards(loss / self.gradient_accumulate_every, self.opt) + if self.debug_time: + print("after loss=", time.time() - d_time) + + del img, aux, img_mean, img_std, aux_mean, aux_std, params_dict + + + + if (self.step + 1) % (min(self.train_num_steps // 100 + 1, 100)) == 0: + print('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % ( + self.step + 1, time.time() - start_time, self.scheduler.get_lr()[0], loss.item())) + + # writer.add_scalar("Loss/train", loss.item(), self.step) + acc_loss = acc_loss + (u_loss.item() / self.gradient_accumulate_every) + + max_norm = 0.01 # Maximum norm for gradients + torch.nn.utils.clip_grad_norm_(self.model.module.restore_fn.parameters(), max_norm) + if self.debug_time: + print("Before Optimization time=", time.time() - d_time) + + if self.fp16: + scaler.step(self.opt) + self.scheduler.step() + scaler.update() + + else: + self.opt.step() + self.scheduler.step() + + if self.debug_time: + print("Optimization time=", time.time() - d_time) + + if self.step % self.update_ema_every == 0: + self.step_ema() + + + # TEST and SAVE + if self.step != 0 and (self.step + 1) % self.save_and_sample_every == 0: + batches = self.batch_size + data_dict = next(self.test_dl) # .cuda() + train_dict = next(self.dl) # .cuda() + + # 'default', "fre_progressive", "x0_step_down" + for routine in ['x0_step_down_fre']: + self.test_data_dict("Train", train_dict, batches, routine) + self.test_data_dict("Test", data_dict, batches, routine) + + self.ema_model.module.restore_fn.train() + self.save() + + self.step += 1 + clean_start = time.time() + + del data_dict + del u_loss + torch.cuda.empty_cache() + gc.collect() + if self.debug_time: + print("Clean time=", time.time() - clean_start) + + print("Iter time = ", time.time() - d_time, "total time = ", time.time() - start_time) + + print('training completed') + + def test_loader(self, sampling_routine): + """ + Computes patient-wise 3D for a dataset of MRI slices. 
+ + """ + print("Starting testing with sampling routine: ", sampling_routine) + + model = self.model # self.ema_model # or self.model + model.eval() # Set model to evaluation mode + + # sampling_routine = ['default', 'x0_step_down', 'x0_step_down_fre', "fre_progressive"]: + num_timesteps = model.module.num_timesteps + patient_wise_pred = {} # num_timesteps + patient_wise_gt = [] + count = 1 + + while True: + batches = 1 #self.batch_size + data_dict = next(self.dl) # .cuda() + + + # Prediction + og_img = data_dict['img'].cuda() + aux = data_dict['aux'].cuda() + img_mean = data_dict['img_mean'].cuda() + img_std = data_dict['img_std'].cuda() + aux_mean = data_dict['aux_mean'].cuda() + aux_std = data_dict['aux_std'].cuda() + params_dict = {'img_mean': img_mean, 'img_std': img_std, 'aux_mean': aux_mean, 'aux_std': aux_std} + + xt, direct_recons, all_images, return_k, all_recons, all_recons_fre, all_masks = ( + model.module.sample( + batch_size=batches, + faded_recon_sample=og_img, + aux=aux, params_dict=params_dict, + sample_routine=sampling_routine)) + + img_std = img_std.view(-1, 1, 1, 1) + img_mean = img_mean.view(-1, 1, 1, 1) + aux_std = aux_std.view(-1, 1, 1, 1) + aux_mean = aux_mean.view(-1, 1, 1, 1) + + + print("before or_img shape: ", og_img.shape, og_img.min(), og_img.max()) + print("before direct_recons shape: ", direct_recons.shape, direct_recons.min(), direct_recons.max()) + # into a Normalized Image + direct_recons_norm = (direct_recons - direct_recons.mean()) / (direct_recons.std()) + all_images_norm = (all_images - all_images.mean()) / (all_images.std()) + + og_img = og_img * img_std + img_mean + _min = og_img.min() + _max = og_img.max() + + # og_img = (og_img - _min) / (_max - _min) # 0 - 1 + all_images = all_images * img_std + img_mean + # all_images = (all_images - _min) / (_max - _min) + all_images_norm = all_images_norm * img_std + img_mean + # all_images_norm = (all_images_norm - _min) / (_max - _min) + + all_recons = all_recons * img_std + img_mean + # all_recons = (all_recons - _min) / (_max - _min) + direct_recons = direct_recons * img_std + img_mean + # direct_recons = (direct_recons - _min) / (_max - _min) + + direct_recons_norm = direct_recons_norm * img_std + img_mean + # direct_recons_norm = (direct_recons_norm - _min) / (_max - _min) + + xt = xt * img_std + img_mean + # xt = (xt - _min) / (_max - _min) + + aux = aux * aux_std + aux_mean + # aux = (aux - aux.min()) / (aux.max() - aux.min()) + all_recons = all_recons.cpu() + + all_recons = torch.clamp(all_recons, 1e-6, 1).mul(255).to(torch.int8) + direct_recons = torch.clamp(direct_recons, 1e-6, 1).mul(255).to(torch.int8) + all_images = torch.clamp(all_images, 1e-6, 1).mul(255).to(torch.int8) + og_img = torch.clamp(og_img, 1e-6, 1).mul(255).to(torch.int8) + + # 24, 1, 128, 128 + # Calculate SSIM and PSNR, LPIPS + ssims = [] + psnrs = [] + + + + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, direct_recons) + lpips = self.lpips(direct_recons.float(), og_img.float()).mean().item() + print(f"=== first step Metrics {sampling_routine}: SSIM: ", ssim_, " PSNR: ", psnr_, " LPIPS: ", lpips, " NMSE: ", nmse_) + + for im in all_recons: + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, im) + ssims.append(ssim_) + psnrs.append(psnr_) + + ssim_, psnr_, nmse_ = self.calculate_metrics(og_img, all_images) + lpips = self.lpips(all_images.float(), og_img.float()).mean().item() + + print(f"=== Final Metrics {sampling_routine}: SSIM: ", ssim_, " PSNR: ", psnr_, " LPIPS: ", lpips, " NMSE: ", nmse_) + + + 
os.makedirs(self.results_folder, exist_ok=True) + + # utils.save_image(xt, str(self.results_folder / f'{self.step}-xt-Noise.png'), nrow=6) + # utils.save_image(all_images, str(self.results_folder / f'{self.step}-full_recons.png'), + # nrow=6) + # utils.save_image(direct_recons, + # str(self.results_folder / f'{self.step}-sample-direct_recons.png'), nrow=6) + # utils.save_image(og_img, str(self.results_folder / f'{self.step}-img.png'), nrow=6) + # utils.save_image(aux, str(self.results_folder / f'{self.step}-aux.png'), nrow=6) + + # plot ssim and psnr in two sub plots same canvas parallel + import matplotlib.pyplot as plt + fig, axs = plt.subplots(2, figsize=(8, 6), dpi=100, sharex=True) + axs[0].plot(ssims, marker='o', linestyle='-', color='blue', label='SSIM') + axs[0].set_title('Structural Similarity Index (SSIM)', fontsize=14) + axs[0].set_ylabel('SSIM', fontsize=12) + axs[0].grid(True, linestyle='--', alpha=0.7) + axs[0].legend() + + # PSNR plot + axs[1].plot(psnrs, marker='o', linestyle='-', color='green', label='PSNR') + axs[1].set_title('Peak Signal-to-Noise Ratio (PSNR)', fontsize=14) + axs[1].set_xlabel('Iterations', fontsize=12) + axs[1].set_ylabel('PSNR (dB)', fontsize=12) + axs[1].grid(True, linestyle='--', alpha=0.7) + axs[1].legend() + + fig.tight_layout() + + plt.savefig(str(self.results_folder / f'{count}-metrics-{sampling_routine}.png')) + + return_k = return_k.cuda() + + combine = torch.cat((return_k, + xt, + all_images, direct_recons, og_img, aux), 2) + + # utils.save_image(combine, str(self.results_folder / f'{self.step}-combine-{routine}.png'), nrow=6) + + # all_recon = all_recons[:, 0] # 50, 1, 128, 128 + # Ensure all_recons is on the CPU + + all_recons = torch.cat(list(all_recons), dim=-1) + all_masks = all_masks.cpu() + all_masks = torch.cat(list(all_masks), dim=-1) + + s = all_recons.shape[-2] + repeats = all_recons.shape[3] // og_img.shape[3] # Calculate repeat factor + # tensor_small = tensor_small.repeat(1, 1, 1, repeats) + og_img = og_img.cpu() + all_recons_residual = all_recons - og_img.repeat(1, 1, 1, repeats) + # all_recons[:, :, :, s:] + all_recons_residual_2 = all_recons[:, :, :, s:] - all_recons[:, :, :, :-s] + padding = torch.zeros_like(all_recons_residual[:, :, :, :s // 2]) + all_recons_residual_2 = torch.cat([padding, all_recons_residual_2, padding], dim=-1) + + all_recons = torch.cat([all_recons, all_recons_residual_2, all_recons_residual], dim=-2) + + # utils.save_image(all_recons, str(self.results_folder / f'{self.step}-all_recons-{routine}.png'), + # nrow=1) + count += 1 + + + def test_loader_3d(self, sampling_routine): + """ + Computes patient-wise 3D for a dataset of MRI slices. 
+ + """ + print("Starting testing with sampling routine: ", sampling_routine) + + model = self.model # self.ema_model # or self.model + model.eval() # Set model to evaluation mode + + # sampling_routine = ['default', 'x0_step_down', 'x0_step_down_fre', "fre_progressive"]: + num_timesteps = model.module.num_timesteps + patient_wise_pred = {} # num_timesteps + patient_wise_gt = [] + + while True: + batches = 1 #self.batch_size + data_dict = next(self.dl) # .cuda() + if data_dict['is_start']: + patient_wise_pred = {} # num_timesteps of list + for i in range(num_timesteps): + patient_wise_pred[i] = [] + patient_wise_gt = [] + + # Original Input + og_img = data_dict['img'].cuda() + aux = data_dict['aux'].cuda() + img_mean = data_dict['img_mean'].cuda() + img_std = data_dict['img_std'].cuda() + aux_mean = data_dict['aux_mean'].cuda() + aux_std = data_dict['aux_std'].cuda() + params_dict = {'img_mean': img_mean, 'img_std': img_std, + 'aux_mean': aux_mean, 'aux_std': aux_std} + + # Prediction + xt, direct_recons, all_images, return_k, all_recons, all_recons_fre, all_masks = ( + model.module.sample( + batch_size=batches, + faded_recon_sample=og_img, + aux=aux, params_dict=params_dict, + sample_routine=sampling_routine)) + + # Handling the output + img_std = img_std.view(-1, 1, 1, 1) + img_mean = img_mean.view(-1, 1, 1, 1) + aux_std = aux_std.view(-1, 1, 1, 1) + aux_mean = aux_mean.view(-1, 1, 1, 1) + + og_img = og_img * img_std + img_mean + _min = og_img.min() + _max = og_img.max() + + og_img = (og_img - _min) / (_max - _min) + all_images = all_images * img_std + img_mean + all_images = (all_images - _min) / (_max - _min) + + all_recons = all_recons * img_std + img_mean + all_recons = (all_recons - _min) / (_max - _min) + direct_recons = direct_recons * img_std + img_mean + direct_recons = (direct_recons - _min) / (_max - _min) + xt = xt * img_std + img_mean + xt = (xt - _min) / (_max - _min) + + aux = aux * aux_std + aux_mean + aux = (aux - aux.min()) / (aux.max() - aux.min()) + all_recons = all_recons.cpu() + + # Save to list + patient_wise_gt.append(og_img) + for i in range(num_timesteps): + patient_wise_pred[i].append(all_recons[i]) + + if data_dict['is_end']: # or len(patient_wise_gt) >=10: # TODO + # 24, 1, 128, 128 + # Calculate SSIM and PSNR, LPIPS + patient_gt = torch.cat(patient_wise_gt, dim=0).squeeze() # C, H, W + + ssims, psnrs, lpips, fid, fid_3d = [], [], [], [], [] + + for i in range(num_timesteps): + patient_pred = torch.cat(patient_wise_pred[i], dim=0).squeeze() + + ssim_, psnr_, fid_, fid3d_, lpips_ = self.calculate_metrics_3d(patient_gt, patient_pred) + print("time step: ", i, "SSIM: ", ssim_, " PSNR: ", psnr_, " LPIPS: ", lpips_, " FID: ", fid_, " FID_3D: ", fid3d_) + + lpips.append(lpips_) + fid.append(fid_) + fid_3d.append(fid3d_) + ssims.append(ssim_) + psnrs.append(psnr_) + + print(f"=== Final Metrics {sampling_routine}: SSIM: ", ssim_, " PSNR: ", psnr_, " LPIPS: ", lpips_) + + + file_id = data_dict['file_id'][0].split("/")[-1] + + os.makedirs(self.results_folder, exist_ok=True) + + import matplotlib.pyplot as plt + fig, axs = plt.subplots(5, figsize=(8, 6), dpi=100, sharex=True) + axs[0].plot(ssims, marker='o', linestyle='-', color='blue', label='SSIM') + axs[0].set_title('Structural Similarity Index (SSIM)', fontsize=14) + axs[0].set_ylabel('SSIM', fontsize=12) + axs[0].grid(True, linestyle='--', alpha=0.7) + axs[0].legend() + + # PSNR plot + axs[1].plot(psnrs, marker='o', linestyle='-', color='green', label='PSNR') + axs[1].set_title('Peak Signal-to-Noise Ratio (PSNR)', 
fontsize=14) + axs[1].set_xlabel('Iterations', fontsize=12) + axs[1].set_ylabel('PSNR (dB)', fontsize=12) + axs[1].grid(True, linestyle='--', alpha=0.7) + axs[1].legend() + + # LPIPS plot + axs[2].plot(lpips, marker='o', linestyle='-', color='red', label='LPIPS') + axs[2].set_title('LPIPS', fontsize=14) + axs[2].set_xlabel('Iterations', fontsize=12) + axs[2].set_ylabel('LPIPS', fontsize=12) + axs[2].grid(True, linestyle='--', alpha=0.7) + axs[2].legend() + + # FID plot + axs[3].plot(fid, marker='o', linestyle='-', color='orange', label='FID') + axs[3].set_title('FID', fontsize=14) + axs[3].set_xlabel('Iterations', fontsize=12) + axs[3].set_ylabel('FID', fontsize=12) + axs[3].grid(True, linestyle='--', alpha=0.7) + axs[3].legend() + + axs[4].plot(fid_3d, marker='o', linestyle='-', color='purple', label='FID_3D') + axs[4].set_title('FID_3D', fontsize=14) + axs[4].set_xlabel('Iterations', fontsize=12) + axs[4].set_ylabel('FID_3D', fontsize=12) + axs[4].grid(True, linestyle='--', alpha=0.7) + axs[4].legend() + + fig.tight_layout() + save = str(self.results_folder / f'{file_id}-metrics-{sampling_routine}.png') + plt.savefig(save) + + print("Save metrics to ", save) + + + + + def test_from_data(self, extra_path, s_times=None): + batches = self.batch_size + og_img = next(self.dl).cuda() + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, faded_recon_sample=og_img, times=s_times) + + og_img = (og_img + 1) * 0.5 + utils.save_image(og_img, str(self.results_folder / f'og-{extra_path}.png'), nrow=6) + + frames_t = [] + frames_0 = [] + + for i in range(len(x0_list)): + print(i) + + x_0 = x0_list[i] + x_0 = (x_0 + 1) * 0.5 + utils.save_image(x_0, str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), str(i)) + frames_0.append(imageio.imread(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'))) + + x_t = xt_list[i] + all_images = (x_t + 1) * 0.5 + utils.save_image(all_images, str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), str(i)) + frames_t.append(imageio.imread(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'))) + + imageio.mimsave(str(self.results_folder / f'Gif-{extra_path}-x0.gif'), frames_0) + imageio.mimsave(str(self.results_folder / f'Gif-{extra_path}-xt.gif'), frames_t) + + def test_with_mixup(self, extra_path): + batches = self.batch_size + og_img_1 = next(self.dl).cuda() + og_img_2 = next(self.dl).cuda() + og_img = (og_img_1 + og_img_2) / 2 + + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, faded_recon_sample=og_img) + + og_img_1 = (og_img_1 + 1) * 0.5 + utils.save_image(og_img_1, str(self.results_folder / f'og1-{extra_path}.png'), nrow=6) + + og_img_2 = (og_img_2 + 1) * 0.5 + utils.save_image(og_img_2, str(self.results_folder / f'og2-{extra_path}.png'), nrow=6) + + og_img = (og_img + 1) * 0.5 + utils.save_image(og_img, str(self.results_folder / f'og-{extra_path}.png'), nrow=6) + + frames_t = [] + frames_0 = [] + + for i in range(len(x0_list)): + print(i) + x_0 = x0_list[i] + x_0 = (x_0 + 1) * 0.5 + utils.save_image(x_0, str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), str(i)) + frames_0.append(Image.open(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'))) + + x_t = xt_list[i] + all_images = (x_t + 1) * 0.5 + 
utils.save_image(all_images, str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), str(i)) + frames_t.append(Image.open(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'))) + + frame_one = frames_0[0] + frame_one.save(str(self.results_folder / f'Gif-{extra_path}-x0.gif'), format="GIF", append_images=frames_0, + save_all=True, duration=100, loop=0) + + frame_one = frames_t[0] + frame_one.save(str(self.results_folder / f'Gif-{extra_path}-xt.gif'), format="GIF", append_images=frames_t, + save_all=True, duration=100, loop=0) + + def test_from_random(self, extra_path): + batches = self.batch_size + og_img = next(self.dl).cuda() + og_img = og_img * 0.9 + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, faded_recon_sample=og_img) + + og_img = (og_img + 1) * 0.5 + utils.save_image(og_img, str(self.results_folder / f'og-{extra_path}.png'), nrow=6) + + frames_t_names = [] + frames_0_names = [] + + for i in range(len(x0_list)): + print(i) + + x_0 = x0_list[i] + x_0 = (x_0 + 1) * 0.5 + utils.save_image(x_0, str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png'), str(i)) + frames_0_names.append(str(self.results_folder / f'sample-{i}-{extra_path}-x0.png')) + + x_t = xt_list[i] + all_images = (x_t + 1) * 0.5 + utils.save_image(all_images, str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), nrow=6) + self.add_title(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png'), str(i)) + frames_t_names.append(str(self.results_folder / f'sample-{i}-{extra_path}-xt.png')) + + frames_0 = [] + frames_t = [] + for i in range(len(x0_list)): + print(i) + frames_0.append(imageio.imread(frames_0_names[i])) + frames_t.append(imageio.imread(frames_t_names[i])) + + imageio.mimsave(str(self.results_folder / f'Gif-{extra_path}-x0.gif'), frames_0) + imageio.mimsave(str(self.results_folder / f'Gif-{extra_path}-xt.gif'), frames_t) + + def controlled_direct_reconstruct(self, extra_path): + batches = self.batch_size + torch.manual_seed(0) + og_img = next(self.dl).cuda() + xt, direct_recons, all_images = self.ema_model.module.sample(batch_size=batches, faded_recon_sample=og_img) + + og_img = (og_img + 1) * 0.5 + utils.save_image(og_img, str(self.results_folder / f'sample-og-{extra_path}.png'), nrow=6) + + all_images = (all_images + 1) * 0.5 + utils.save_image(all_images, str(self.results_folder / f'sample-recon-{extra_path}.png'), nrow=6) + + direct_recons = (direct_recons + 1) * 0.5 + utils.save_image(direct_recons, str(self.results_folder / f'sample-direct_recons-{extra_path}.png'), nrow=6) + + xt = (xt + 1) * 0.5 + utils.save_image(xt, str(self.results_folder / f'sample-xt-{extra_path}.png'), + nrow=6) + + self.save() + + def fid_distance_decrease_from_manifold(self, fid_func, start=0, end=1000): + + all_samples = [] + dataset = self.ds + + print(len(dataset)) + for idx in range(len(dataset)): + img = dataset[idx] + img = torch.unsqueeze(img, 0).cuda() + if idx > start: + all_samples.append(img[0]) + if idx % 1000 == 0: + print(idx) + if end is not None: + if idx == end: + print(idx) + break + + all_samples = torch.stack(all_samples) + blurred_samples = None + original_sample = None + deblurred_samples = None + direct_deblurred_samples = None + + sanity_check = blurred_samples + + cnt = 0 + while cnt < all_samples.shape[0]: + og_x = all_samples[cnt: cnt + 50] + og_x = og_x.cuda() + og_x = 
og_x.type(torch.cuda.FloatTensor) + og_img = og_x + print(og_img.shape) + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=og_img.shape[0], + faded_recon_sample=og_img, + times=None) + + og_img = og_img.to('cpu') + blurry_imgs = xt_list[0].to('cpu') + deblurry_imgs = x0_list[-1].to('cpu') + direct_deblurry_imgs = x0_list[0].to('cpu') + + og_img = og_img.repeat(1, 3 // og_img.shape[1], 1, 1) + blurry_imgs = blurry_imgs.repeat(1, 3 // blurry_imgs.shape[1], 1, 1) + deblurry_imgs = deblurry_imgs.repeat(1, 3 // deblurry_imgs.shape[1], 1, 1) + direct_deblurry_imgs = direct_deblurry_imgs.repeat(1, 3 // direct_deblurry_imgs.shape[1], 1, 1) + + og_img = (og_img + 1) * 0.5 + blurry_imgs = (blurry_imgs + 1) * 0.5 + deblurry_imgs = (deblurry_imgs + 1) * 0.5 + direct_deblurry_imgs = (direct_deblurry_imgs + 1) * 0.5 + + if cnt == 0: + print(og_img.shape) + print(blurry_imgs.shape) + print(deblurry_imgs.shape) + print(direct_deblurry_imgs.shape) + + if sanity_check: + folder = './sanity_check/' + create_folder(folder) + + san_imgs = og_img[0: 32] + utils.save_image(san_imgs, str(folder + f'sample-og.png'), nrow=6) + + san_imgs = blurry_imgs[0: 32] + utils.save_image(san_imgs, str(folder + f'sample-xt.png'), nrow=6) + + san_imgs = deblurry_imgs[0: 32] + utils.save_image(san_imgs, str(folder + f'sample-recons.png'), nrow=6) + + san_imgs = direct_deblurry_imgs[0: 32] + utils.save_image(san_imgs, str(folder + f'sample-direct-recons.png'), nrow=6) + + if blurred_samples is None: + blurred_samples = blurry_imgs + else: + blurred_samples = torch.cat((blurred_samples, blurry_imgs), dim=0) + + if original_sample is None: + original_sample = og_img + else: + original_sample = torch.cat((original_sample, og_img), dim=0) + + if deblurred_samples is None: + deblurred_samples = deblurry_imgs + else: + deblurred_samples = torch.cat((deblurred_samples, deblurry_imgs), dim=0) + + if direct_deblurred_samples is None: + direct_deblurred_samples = direct_deblurry_imgs + else: + direct_deblurred_samples = torch.cat((direct_deblurred_samples, direct_deblurry_imgs), dim=0) + + cnt += og_img.shape[0] + + print(blurred_samples.shape) + print(original_sample.shape) + print(deblurred_samples.shape) + print(direct_deblurred_samples.shape) + + fid_blur = fid_func(samples=[original_sample, blurred_samples]) + rmse_blur = torch.sqrt(torch.mean((original_sample - blurred_samples) ** 2)) + ssim_blur = ssim(original_sample, blurred_samples, data_range=1, size_average=True) + print(f'The FID of blurry images with original image is {fid_blur}') + print(f'The RMSE of blurry images with original image is {rmse_blur}') + print(f'The SSIM of blurry images with original image is {ssim_blur}') + + fid_deblur = fid_func(samples=[original_sample, deblurred_samples]) + rmse_deblur = torch.sqrt(torch.mean((original_sample - deblurred_samples) ** 2)) + ssim_deblur = ssim(original_sample, deblurred_samples, data_range=1, size_average=True) + print(f'The FID of deblurred images with original image is {fid_deblur}') + print(f'The RMSE of deblurred images with original image is {rmse_deblur}') + print(f'The SSIM of deblurred images with original image is {ssim_deblur}') + + print(f'Hence the improvement in FID using sampling is {fid_blur - fid_deblur}') + + fid_direct_deblur = fid_func(samples=[original_sample, direct_deblurred_samples]) + rmse_direct_deblur = torch.sqrt(torch.mean((original_sample - direct_deblurred_samples) ** 2)) + ssim_direct_deblur = ssim(original_sample, direct_deblurred_samples, data_range=1, size_average=True) + 
print(f'The FID of direct deblurred images with original image is {fid_direct_deblur}') + print(f'The RMSE of direct deblurred images with original image is {rmse_direct_deblur}') + print(f'The SSIM of direct deblurred images with original image is {ssim_direct_deblur}') + + print(f'Hence the improvement in FID using direct sampling is {fid_blur - fid_direct_deblur}') + + def paper_invert_section_images(self, s_times=None): + + cnt = 0 + for i in range(50): + batches = self.batch_size + og_img = next(self.dl).cuda() + print(og_img.shape) + + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, + faded_recon_sample=og_img, + times=s_times) + og_img = (og_img + 1) * 0.5 + + for j in range(og_img.shape[0]//3): + original = og_img[j: j + 1] + utils.save_image(original, str(self.results_folder / f'original_{cnt}.png'), nrow=3) + + direct_recons = x0_list[0][j: j + 1] + direct_recons = (direct_recons + 1) * 0.5 + utils.save_image(direct_recons, str(self.results_folder / f'direct_recons_{cnt}.png'), nrow=3) + + sampling_recons = x0_list[-1][j: j + 1] + sampling_recons = (sampling_recons + 1) * 0.5 + utils.save_image(sampling_recons, str(self.results_folder / f'sampling_recons_{cnt}.png'), nrow=3) + + blurry_image = xt_list[0][j: j + 1] + blurry_image = (blurry_image + 1) * 0.5 + utils.save_image(blurry_image, str(self.results_folder / f'blurry_image_{cnt}.png'), nrow=3) + + blurry_image = cv2.imread(f'{self.results_folder}/blurry_image_{cnt}.png') + direct_recons = cv2.imread(f'{self.results_folder}/direct_recons_{cnt}.png') + sampling_recons = cv2.imread(f'{self.results_folder}/sampling_recons_{cnt}.png') + original = cv2.imread(f'{self.results_folder}/original_{cnt}.png') + + black = [0, 0, 0] + blurry_image = cv2.copyMakeBorder(blurry_image, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + direct_recons = cv2.copyMakeBorder(direct_recons, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + sampling_recons = cv2.copyMakeBorder(sampling_recons, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + original = cv2.copyMakeBorder(original, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=black) + + im_h = cv2.hconcat([blurry_image, direct_recons, sampling_recons, original]) + cv2.imwrite(f'{self.results_folder}/all_{cnt}.png', im_h) + + cnt += 1 + + def paper_showing_diffusion_images(self, s_times=None): + + cnt = 0 + to_show = [0, 1, 2, 4, 8, 16, 32, 64, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100] + + for i in range(100): + batches = self.batch_size + og_img = next(self.dl).cuda() + print(og_img.shape) + + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=batches, faded_recon_sample=og_img, times=s_times) + + for k in range(xt_list[0].shape[0]): + lst = [] + + for j in range(len(xt_list)): + x_t = xt_list[j][k] + x_t = (x_t + 1) * 0.5 + utils.save_image(x_t, str(self.results_folder / f'x_{len(xt_list)-j}_{cnt}.png'), nrow=1) + x_t = cv2.imread(f'{self.results_folder}/x_{len(xt_list)-j}_{cnt}.png') + if j in to_show: + lst.append(x_t) + + x_0 = x0_list[-1][k] + x_0 = (x_0 + 1) * 0.5 + utils.save_image(x_0, str(self.results_folder / f'x_best_{cnt}.png'), nrow=1) + x_0 = cv2.imread(f'{self.results_folder}/x_best_{cnt}.png') + lst.append(x_0) + im_h = cv2.hconcat(lst) + cv2.imwrite(f'{self.results_folder}/all_{cnt}.png', im_h) + cnt += 1 + + def test_from_data_save_results(self): + batch_size = 100 + dl = data.DataLoader(self.ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=16, + drop_last=True) + + all_samples = None + + for i, img in enumerate(dl, 
0): + print(i) + print(img.shape) + if all_samples is None: + all_samples = img + else: + all_samples = torch.cat((all_samples, img), dim=0) + + # break + + # create_folder(f'{self.results_folder}/') + blurred_samples = None + original_sample = None + deblurred_samples = None + direct_deblurred_samples = None + + sanity_check = 1 + + orig_folder = f'{self.results_folder}_orig/' + create_folder(orig_folder) + + blur_folder = f'{self.results_folder}_blur/' + create_folder(blur_folder) + + d_deblur_folder = f'{self.results_folder}_d_deblur/' + create_folder(d_deblur_folder) + + deblur_folder = f'{self.results_folder}_deblur/' + create_folder(deblur_folder) + + cnt = 0 + while cnt < all_samples.shape[0]: + print(cnt) + og_x = all_samples[cnt: cnt + 32] + og_x = og_x.cuda() + og_x = og_x.type(torch.cuda.FloatTensor) + og_img = og_x + x0_list, xt_list = self.ema_model.module.all_sample(batch_size=og_img.shape[0], faded_recon_sample=og_img, times=None) + + og_img = og_img.to('cpu') + blurry_imgs = xt_list[0].to('cpu') + deblurry_imgs = x0_list[-1].to('cpu') + direct_deblurry_imgs = x0_list[0].to('cpu') + + og_img = og_img.repeat(1, 3 // og_img.shape[1], 1, 1) + blurry_imgs = blurry_imgs.repeat(1, 3 // blurry_imgs.shape[1], 1, 1) + deblurry_imgs = deblurry_imgs.repeat(1, 3 // deblurry_imgs.shape[1], 1, 1) + direct_deblurry_imgs = direct_deblurry_imgs.repeat(1, 3 // direct_deblurry_imgs.shape[1], 1, 1) + + og_img = (og_img + 1) * 0.5 + blurry_imgs = (blurry_imgs + 1) * 0.5 + deblurry_imgs = (deblurry_imgs + 1) * 0.5 + direct_deblurry_imgs = (direct_deblurry_imgs + 1) * 0.5 + + if cnt == 0: + print(og_img.shape) + print(blurry_imgs.shape) + print(deblurry_imgs.shape) + print(direct_deblurry_imgs.shape) + + if blurred_samples is None: + blurred_samples = blurry_imgs + else: + blurred_samples = torch.cat((blurred_samples, blurry_imgs), dim=0) + + if original_sample is None: + original_sample = og_img + else: + original_sample = torch.cat((original_sample, og_img), dim=0) + + if deblurred_samples is None: + deblurred_samples = deblurry_imgs + else: + deblurred_samples = torch.cat((deblurred_samples, deblurry_imgs), dim=0) + + if direct_deblurred_samples is None: + direct_deblurred_samples = direct_deblurry_imgs + else: + direct_deblurred_samples = torch.cat((direct_deblurred_samples, direct_deblurry_imgs), dim=0) + + cnt += og_img.shape[0] + + print(blurred_samples.shape) + print(original_sample.shape) + print(deblurred_samples.shape) + print(direct_deblurred_samples.shape) + + for i in range(blurred_samples.shape[0]): + utils.save_image(original_sample[i], f'{orig_folder}{i}.png', nrow=1) + utils.save_image(blurred_samples[i], f'{blur_folder}{i}.png', nrow=1) + utils.save_image(deblurred_samples[i], f'{deblur_folder}{i}.png', nrow=1) + utils.save_image(direct_deblurred_samples[i], f'{d_deblur_folder}{i}.png', nrow=1) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/draw/frequency_sampling.py b/MRI_recon/new_code/Frequency-Diffusion-main/draw/frequency_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..30633cf87161ffff68129e3b823d5f1b0d04f2c2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/draw/frequency_sampling.py @@ -0,0 +1,284 @@ +import torch + +from utils.k_degrade_utils import * + + +if __name__ == "__main__": + # First STEP + import matplotlib.pyplot as plt + import numpy as np, os + + os.makedirs("outputs", exist_ok=True) + + os.makedirs("outputs/low-fre-first", exist_ok=True) + os.makedirs("outputs/random-sample", exist_ok=True) + + + 
image_size = 256 + accelerated_factor = 6 + center_fraction = 0.04 + time_step = 25 + + + masks = get_ksu_kernel(time_step, image_size, "LogSamplingRate", + accelerated_factor=accelerated_factor, center_fraction=center_fraction) # LogSamplingRate + + + batch_size = 1 + + img = plt.imread("./assets/BraTS20_Training_001_86_t1.png") + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + + print("input img shape: ", img.shape) + + # to gray scale + if len(img.shape) == 3 and img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + + # img = np.transpose(img, (2, 0, 1)) + # img = img[0] + img = np.expand_dims(img, axis=0) + img = torch.from_numpy(img).unsqueeze(0).float() + original_img = img.clone() + + + rand_kernels = [] + rand_x = torch.randint(0, image_size + 1, (batch_size,)).long() + + img = img #* 2 - 1 # + + masked_img = [] + + for m in masks: + m = m.unsqueeze(0) + img = apply_ksu_kernel(img, m) + masked_img.append(img) + + save_masks = masks + masks = np.concatenate(masks, axis=-1)[0] + masked_img = torch.concat(masked_img, dim=-1).numpy() #+ 1) * 0.5 + + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + + + + img = np.concatenate([masks, masked_img], axis=0) + min_ = masked_img.min() + max_ = masked_img.max() + + out = img[image_size: 2 * image_size, : image_size] + fft, _ = apply_tofre(torch.from_numpy(out), torch.from_numpy(out)) # complex + fft = np.abs(fft.numpy()) + fft = np.log(fft) + fft = (fft - fft.min()) / (fft.max() - fft.min()) + + for i in range(time_step+1): + out = img[image_size: 2 * image_size, i * image_size: (i + 1) * image_size] + + # out = (out - out.min()) / (out.max() - out.min()) + out = (out - min_) / (max_ - min_) + plt.imsave(f"outputs/low-fre-first/{i}_image.png", out, cmap='gray') + + if i != 0: + out = img[:image_size, i * image_size:(i + 1) * image_size] + out = (out - out.min()) / (out.max() - out.min()) + + plt.imsave(f"outputs/low-fre-first/{i}_mask.png", out, cmap='gray') + + save_fft = fft * out + plt.imsave(f"outputs/low-fre-first/{i}_fft.png", save_fft, cmap='gray') + + + else: + diff = np.ones((image_size, image_size, 3), dtype=np.uint8) * 255 # All 255 (White)ve + ones = diff.astype(np.float32) / 255.0 + print("ones shape: ", ones.shape, ones.min(), ones.max()) + + plt.imsave(f"outputs/low-fre-first/{i}_mask.png", ones, cmap='gray') + plt.imsave(f"outputs/low-fre-first/{i}_fft.png", fft, cmap='gray') + + try: + diff = img[:image_size, (i-1) * image_size:(i) * image_size] - \ + img[:image_size, (i) * image_size:(i + 1) * image_size] + + except: + diff = np.zeros_like(img[:image_size, : image_size]) + + # plt.imsave(f"outputs/low-fre-first/{i}_mask_diff.png", diff, cmap='gray') + # print("diff shape: ", diff.shape, diff.min(), diff.max()) + + + diffsig = diff * fft + # save it as a red img, but the bg is trasparent + alpha_channel = np.full_like(diff, 255, dtype=np.uint8) * diff + alpha_channel = np.expand_dims(alpha_channel, axis=-1) + + diff = (diff * 255).astype(np.uint8) + diff = np.stack([diff, np.zeros_like(diff), np.zeros_like(diff)], axis=-1) + # Create an alpha channel (255 for full opacity) + + # Concatenate RGB with Alpha channel + diff = np.concatenate([diff, alpha_channel], axis=-1) + diff = diff.astype(np.uint8) + + # print("diff shape: ", diff.shape, diff.min(), diff.max()) + + + plt.imsave(f"outputs/low-fre-first/{i}_mask_diff_red.png", diff, cmap='gray') + 
plt.imsave(f"outputs/low-fre-first/{i}_mask_diffsig.png", diffsig, cmap='gray') + + + plt.imsave("outputs/masked_img.png", masked_img, cmap='gray') + plt.figure(figsize=(5*time_step, 10)) + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.show() + + print("\n\nSecond stage...") + + # ------------------------------- ------------------------------- ------------------------------- + # ------------------------------- ------------------------------- ------------------------------- + + # Second STEP completely Random + import matplotlib.pyplot as plt + import numpy as np + + + final_mask = save_masks[-1][0].numpy() + new_masks = [] + + plt.imshow(final_mask, cmap='gray') + plt.show() + + height, width = final_mask.shape + print("final_mask shape: ", final_mask.shape) + + # Count ones and zeros + ones = np.sum(final_mask[0] == 1) + zeros = np.sum(final_mask[0] == 0) + + print("Initial ones count:", ones) + print("Initial zeros count:", zeros) + + # Identify initially filled and empty strips + initial_filled_indices = np.where(final_mask[0] == 1)[0] + remaining_indices = np.where(final_mask[0] == 0)[0] + + # Shuffle remaining indices to randomize filling order + np.random.shuffle(remaining_indices) + + # Split remaining indices into `time_step` parts + fills_per_step = np.array_split(remaining_indices, time_step) + + masked_img = [] # Store masks at each step + + # Copy initial mask + current_mask = final_mask.copy() + new_masks.append(current_mask.copy()) # Store initial state + + # Fill remaining strips over time + for i in range(time_step): + current_mask[:, fills_per_step[i - 1]] = 1 # Fill new strips + new_masks.append(current_mask.copy()) # Store new mask + # current_mask.append(final_mask) # Store new mask + + new_masks = new_masks[::-1] # Reverse list to get correct order + masked_img = [] + + for m in new_masks: + m = torch.from_numpy(m) #.unsqueeze(0) + + img = apply_ksu_kernel(original_img, m) + masked_img.append(img) + + masks = np.concatenate(new_masks, axis=-1) + masked_img = torch.concat(masked_img, dim=-1).numpy() #+ 1) * 0.5 + + masked_img = np.transpose(masked_img, (0, 2, 3, 1))[0, ..., 0] + + print("masked_img shape: ", masked_img.shape) + + + # masked_img = cv2.cvtColor(masked_img, cv2.COLOR_RGB2GRAY) + # masked_img = (masked_img - masked_img.min()) / (masked_img.max() - masked_img.min()) + + img = np.concatenate([masks, masked_img], axis=0) + + + min_ = masked_img.min() + max_ = masked_img.max() + + out = img[image_size: 2 * image_size, : image_size] + fft, _ = apply_tofre(torch.from_numpy(out), torch.from_numpy(out)) # complex + fft = np.abs(fft.numpy()) + fft = np.log(fft) + fft = (fft - fft.min()) / (fft.max() - fft.min()) + + + for i in range(time_step+1): + # if i % 3 != 0: + # continue + out = img[image_size : 2*image_size, i * image_size : (i + 1) * image_size] + + # out = (out - out.min()) / (out.max() - out.min()) + out = (out - min_) / (max_ - min_) + plt.imsave(f"outputs/random-sample/{i}_image.png", out, cmap='gray') + + if i != 0: + out = img[:image_size, i * image_size:(i + 1) * image_size] + out = (out - out.min()) / (out.max() - out.min()) + + plt.imsave(f"outputs/random-sample/{i}_mask.png", out, cmap='gray') + + save_fft = fft * out + plt.imsave(f"outputs/random-sample/{i}_fft.png", save_fft, cmap='gray') + + noise = np.random.normal(0, 0.2*np.log((time_step-i)+1), out.shape) * fft + save_fft = fft + noise * (1-out) # Sigma + plt.imsave(f"outputs/random-sample/{i}_fft_reverse.png", save_fft, cmap='gray') + + + else: + ones = np.ones_like(out) * 255 + 
plt.imsave(f"outputs/random-sample/{i}_mask.png", ones, cmap='gray') + plt.imsave(f"outputs/random-sample/{i}_fft.png", fft, cmap='gray') + + + try: + diff = img[:image_size, (i-1) * image_size:(i) * image_size] - \ + img[:image_size, (i) * image_size:(i + 1) * image_size] + + except: + diff = np.zeros_like(img[:image_size, : image_size]) + + + plt.imsave(f"outputs/random-sample/{i}_mask_diff.png", diff, cmap='gray') + # print("diff shape: ", diff.shape, diff.min(), diff.max()) + + # save it as a red img, but the bg is trasparent + alpha_channel = np.full_like(diff, 255, dtype=np.uint8) * diff + alpha_channel = np.expand_dims(alpha_channel, axis=-1) + + diff = (diff * 255).astype(np.uint8) + diff = np.stack([diff, np.zeros_like(diff), np.zeros_like(diff)], axis=-1) + # Create an alpha channel (255 for full opacity) + + # Concatenate RGB with Alpha channel + diff = np.concatenate([diff, alpha_channel], axis=-1) + diff = diff.astype(np.uint8) + + print("diff shape: ", diff.shape, diff.min(), diff.max()) + + plt.imsave(f"outputs/random-sample/{i}_mask_diff_red.png", diff, cmap='gray') + + + plt.imsave("outputs/img.png", img, cmap='gray') + # plt.figure(figsize=(5*time_step, 10)) + + plt.imshow(img, cmap='gray') # (1, 128, 1280) + plt.tight_layout() + plt.show() + + print("\n\nSecond stage...") diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/draw/utils/k_degrade_utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/draw/utils/k_degrade_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9991a66043ffd35ae4cac7fd64f8d4780f7e97f6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/draw/utils/k_degrade_utils.py @@ -0,0 +1,312 @@ +# --------------------------- +# Fade kernels +# --------------------------- +import cv2, torch +import numpy as np +import torchgeometry as tgm +from torch.fft import fft2, ifft2, fftshift, ifftshift, fftn, ifftn +import sys, os + +from utils.mask_utils import RandomMaskFunc, EquispacedMaskFunc + + +from torch import nn +import matplotlib.pyplot as plt + +def get_fade_kernel(dims, std): + fade_kernel = tgm.image.get_gaussian_kernel2d(dims, std) + fade_kernel = fade_kernel / torch.max(fade_kernel) + fade_kernel = torch.ones_like(fade_kernel) - fade_kernel + # if device_of_kernel == 'cuda': + # fade_kernel = fade_kernel.cuda() + fade_kernel = fade_kernel[1:, 1:] + return fade_kernel + + + +def get_fade_kernels(fade_routine, num_timesteps, image_size, kernel_std,initial_mask): + kernels = [] + for i in range(num_timesteps): + if fade_routine == 'Incremental': + kernels.append(get_fade_kernel((image_size + 1, image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + elif fade_routine == 'Constant': + kernels.append(get_fade_kernel( + (image_size + 1, image_size + 1), + (kernel_std, kernel_std))) + + elif fade_routine == 'Random_Incremental': + kernels.append(get_fade_kernel((2 * image_size + 1, 2 * image_size + 1), + (kernel_std * (i + initial_mask), + kernel_std * (i + initial_mask)))) + return torch.stack(kernels) + + +# --------------------------- +# Kspace kernels +# --------------------------- +# cartesian_regular +def get_mask_func(mask_method, af, cf): + if mask_method == 'cartesian_regular': + return EquispacedMaskFractionFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == 'cartesian_random': + return RandomMaskFunc(center_fractions=[cf], accelerations=[af]) + + elif mask_method == "random": + return RandomMaskFunc([cf], [af]) + + elif mask_method == "randompatch": + 
return RandomPatchFunc([cf], [af]) + + elif mask_method == "equispaced": + return EquispacedMaskFunc([cf], [af]) + + else: + raise NotImplementedError + + +use_fix_center_ratio = False + +class Noisy_Patch(nn.Module): + def __init__(self): + super(Noisy_Patch, self).__init__() + self.af_list = [] + self.cf_list = [] + self.fe_list = [] + self.pe_list = [] + self.seed = 0 + + def append_list(self, at, cf, fe, pe): + self.af_list.append(at) + self.cf_list.append(cf) + self.fe_list.append(fe) + self.pe_list.append(pe) + + def get_noisy_patches(self, t): + af = self.af_list[t] + cf = self.cf_list[t] + fe = self.fe_list[t] + pe = self.pe_list[t] + + patch_mask = get_mask_func("randompatch", af, cf) + mask_, _ = patch_mask((fe, pe, 1), seed=self.seed) # mask (numpy): (fe, pe) + return mask_ + + def forward(self, mask, ts): + # Step 1, Random Drop elements to learn intra-stride effects, patch-wise + # print("use_patch_kernel forward:", t) + # print("mask = ", mask.shape) + # masks_ = [] + for id, t in enumerate(ts): + mask_ = self.get_noisy_patches(t)[0] + # print("mask_ = ", mask_.shape) + # print("mask[id, t] =", mask[t].shape) + + mask[t] = mask_.to(mask[t].device) * mask[t] + self.seed += ts[0].item() + + # masks_ = torch.stack(masks_).cuda() + # print("masks_ = ", masks_.shape) + # print("mask = ", mask.shape) # B, T, H, W + + return mask + +get_noisy_patches = Noisy_Patch() + + +def get_ksu_mask(mask_method, af, cf, pe, fe, seed=0, is_training=False, sort_center=True): + mask_func = get_mask_func(mask_method, af, cf) # acceleration factor, center fraction + + if mask_method in ['cartesian_regular', 'cartesian_random', 'equispaced']: + print("pe:", pe) + mask = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'equispaced': + mask = mask_func((1, pe, 1), seed=seed) # mask (torch): (1, pe, 1) + # Extent the stride + mask = mask.permute(0, 2, 1).repeat(1, fe, 1) # mask (torch): (1, pe, 1) --> (1, 1, pe) --> (1, fe, pe) + + # prepare noisy patches + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, fe, pe) + + elif mask_method == 'gaussian_2d': + mask, _ = mask_func(resolution=(fe, pe), accel=af, sigma=100, seed=seed) # mask (numpy): (fe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (fe, pe) --> (1, fe, pe) + + elif mask_method in ['radial_add', 'radial_sub', 'spiral_add', 'spiral_sub']: + sr = 1 / af + mask = mask_func(mask_type=mask_method, mask_sr=sr, res=pe, seed=seed) # mask (numpy): (pe, pe) + mask = torch.from_numpy(mask[np.newaxis, :, :]) # mask (torch): (pe, pe) --> (1, pe, pe) + + else: + raise NotImplementedError + + # print("return mask = ", mask.shape) + return mask + + + +def get_ksu_kernel(timesteps, image_size, + ksu_routine="LogSamplingRate", + accelerated_factor=4, center_fraction=0.08, accelerate_mask=None, sort_center=True): + + if accelerated_factor == 4: + mask_method, center_fraction = "cartesian_random", center_fraction #0.08 # 0.15 + + else: + mask_method, center_fraction = "equispaced", center_fraction # 0.04 + + + center_ratio_factor = center_fraction * accelerated_factor + + masks = [] + noisy_masks = [] + ksu_mask_pe = ksu_mask_fe = image_size # , ksu_mask_pe=320, ksu_mask_fe=320 + # ksu_mask_fe + if ksu_routine == 'LinearSamplingRate': + # Generate 
the sampling rate list with torch.linspace, reversed, and skip the first element + sr_list = torch.linspace(start=1/accelerated_factor, end=1, steps=timesteps + 1).flip(0) + sr_list = [sr.item() for sr in sr_list] + # Start from 0.01 + for sr in sr_list: + # sr = sr.item() + af = 1 / sr # * accelerated_factor # acceleration factor + cf = center_fraction if use_fix_center_ratio else sr_list[0] * center_ratio_factor + + masks.append(get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe)) + + elif ksu_routine == 'LogSamplingRate': + + # Generate the sampling rate list with torch.logspace, reversed, and skip the first element + sr_list = torch.logspace(start=-torch.log10(torch.tensor(accelerated_factor)), + end=0, steps=timesteps + 1).flip(0) + + sr_list = [sr.item() for sr in sr_list] + af = 1 / sr_list[-1] + cf = center_fraction if use_fix_center_ratio else sr_list[-1] * center_ratio_factor + + + if isinstance(accelerate_mask, type(None)): + cache_mask = get_ksu_mask(mask_method, af, cf, pe=ksu_mask_pe, fe=ksu_mask_fe, sort_center=sort_center) + print("cache_mask = ", cache_mask.shape) # torch.Size([1, 320, 320]) + else: + cache_mask = accelerate_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + masks.append(cache_mask) + + sr_list = sr_list[:-1][::-1] #.flip(0) # Flip? + + for sr in sr_list: + af = 1 / sr + cf = center_fraction if use_fix_center_ratio else sr * center_ratio_factor + # print("af = ", af, cf) + + H, W = cache_mask.shape[1], cache_mask.shape[2] + new_mask = cache_mask.clone() + + # Add additional lines to the mask based on new acceleration factor + total_lines = H + sampled_lines = int(total_lines / af) + existing_lines = new_mask.squeeze(0).sum(dim=0).nonzero(as_tuple=True)[0].tolist() + + remaining_lines = [i for i in range(total_lines) if i not in existing_lines] + + if sampled_lines > len(existing_lines): + center = W // 2 + additional_lines = sampled_lines - len(existing_lines) # sample number + + sorted_indices = sorted(remaining_lines, key=lambda x: abs(x - center)) + + # Take the closest `additional_lines` indices + sampled_indices = sorted_indices[:additional_lines] + + # Remove sampled indices from remaining_lines + for idx in sampled_indices: + remaining_lines.remove(idx) + + # Update new_mask for each sampled index + for idx in sampled_indices: + new_mask[:, :, idx] = 1.0 + + + + cache_mask = new_mask + + af_new = 1.0 + (af - 1.0) / 2 + get_noisy_patches.append_list(af_new, cf, ksu_mask_fe, ksu_mask_pe) + + + masks.append(cache_mask) + + # reverse + masks = masks[::-1] + noisy_masks = masks # noisy_masks[::-1] + + + elif mask_method == 'gaussian_2d': + raise NotImplementedError("Gaussian 2D mask type is not implemented.") + + else: + raise NotImplementedError(f'Unknown k-space undersampling routine {ksu_routine}') + + # Return masks, excluding the first one + return masks + + + +class high_fre_mask: + def __init__(self): + self.mask_cache = {} + + def __call__(self, H, W): + if (H, W) in self.mask_cache: + return self.mask_cache[(H, W)] + center_x, center_y = H // 2, W // 2 + radius = H//8 # 影响的频率范围半径 + + high_freq_mask = torch.ones(H, W) + for i in range(H): + for j in range(W): + if (i - center_x) ** 2 + (j - center_y) ** 2 <= radius ** 2: + high_freq_mask[i, j] = 0.0 + self.mask_cache[(H, W)] = high_freq_mask + return high_freq_mask + + +high_fre_mask_cls = high_fre_mask() + + + +def apply_ksu_kernel(x_start, mask): + fft, mask = apply_tofre(x_start, mask) + fft = fft * mask + x_ksu = 
apply_to_spatial(fft) + + return x_ksu + +# from dataloaders.math import ifft2c, fft2c, complex_abs + +def apply_tofre(x_start, mask): + # B, C, H, W = x_start.shape + kspace = fftshift(fft2(x_start, norm=None, dim=(-2, -1)), dim=(-2, -1)) # Default: all dimensions + mask = mask.to(kspace.device) + return kspace, mask + +def apply_to_spatial(fft): + x_ksu = ifft2(ifftshift(fft, dim=(-2, -1)), norm=None, dim=(-2, -1)) # ortho + # After ifftn, the output is already in the spatial domain + x_ksu = x_ksu.real #torch.abs(x_ksu) # + return x_ksu + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/draw/utils/mask_utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/draw/utils/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..43b2d74d38a5add14c9815b3a883b53f7e49a0fc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/draw/utils/mask_utils.py @@ -0,0 +1,195 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). 
+ + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/LICENSE b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/README.md b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8f9aaadef4dd0210e6f11eb09f082c241e08051e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/README.md @@ -0,0 +1,97 @@ +# FSMNet +FSMNet efficiently explores global dependencies across different modalities. Specifically, the features for each modality are extracted by the Frequency-Spatial Feature Extraction (FSFE) module, featuring a frequency branch and a spatial branch. Benefiting from the global property of the Fourier transform, the frequency branch can efficiently capture global dependency with an image-size receptive field, while the spatial branch can extract local features. 
To exploit complementary information from the auxiliary modality, we propose a Cross-Modal Selective fusion (CMS-fusion) module that selectively incorporates the frequency and spatial features from the auxiliary modality to enhance the corresponding branch of the target modality. To further integrate the enhanced global features from the frequency branch and the enhanced local features from the spatial branch, we develop a Frequency-Spatial fusion (FS-fusion) module, resulting in a comprehensive feature representation for the target modality. + +
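+The split between the two branches is easiest to see in code. Below is a minimal, hypothetical PyTorch sketch of the frequency-spatial idea (illustrative only, not the released FSMNet implementation; the class name and layer sizes are made up): the frequency branch mixes the 2-D FFT spectrum with 1x1 convolutions, so every output pixel depends on the whole image, while the spatial branch uses an ordinary 3x3 convolution for local features.
+
+```python
+import torch
+import torch.nn as nn
+
+class FreqSpatialBlock(nn.Module):
+    """Toy frequency/spatial feature block (names and sizes are illustrative only)."""
+    def __init__(self, channels):
+        super().__init__()
+        # frequency branch: 1x1 convs act on the real/imag parts of the spectrum,
+        # which gives an image-size (global) receptive field
+        self.freq_mix = nn.Sequential(
+            nn.Conv2d(2 * channels, 2 * channels, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(2 * channels, 2 * channels, kernel_size=1),
+        )
+        # spatial branch: plain 3x3 conv captures local structure
+        self.spatial = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        spec = torch.fft.fft2(x, norm="ortho")                      # complex spectrum
+        spec = self.freq_mix(torch.cat([spec.real, spec.imag], dim=1))
+        real, imag = spec.chunk(2, dim=1)
+        freq_feat = torch.fft.ifft2(torch.complex(real, imag), norm="ortho").real
+        return freq_feat + self.spatial(x)                          # fuse global and local features
+
+print(FreqSpatialBlock(16)(torch.randn(1, 16, 64, 64)).shape)  # torch.Size([1, 16, 64, 64])
+```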

+ +## Paper + +Accelerated Multi-Contrast MRI Reconstruction via Frequency and Spatial Mutual Learning
+[Qi Chen](https://scholar.google.com/citations?user=4Q5gs2MAAAAJ&hl=en)1, [Xiaohan Xing](https://hathawayxxh.github.io/)2, *, [Zhen Chen](https://franciszchen.github.io/)3, [Zhiwei Xiong](http://staff.ustc.edu.cn/~zwxiong/)1
+1 University of Science and Technology of China,
+2 Stanford University,
+3 Centre for Artificial Intelligence and Robotics (CAIR), HKISI-CAS
+MICCAI, 2024
+[paper](http://arxiv.org/abs/2409.14113) | [code](https://github.com/qic999/FSMNet) | [huggingface](https://huggingface.co/datasets/qicq1c/MRI_Reconstruction)
+
+## 0. Installation
+
+```bash
+git clone https://github.com/qic999/FSMNet.git
+cd FSMNet
+```
+
+See [installation instructions](documents/INSTALL.md) to create an environment and install the requirements.
+
+## 1. Prepare datasets
+Download the BraTS and fastMRI datasets and save them to the `datapath` directory.
+```
+cd $datapath
+# download brats dataset
+wget https://huggingface.co/datasets/qicq1c/MRI_Reconstruction/resolve/main/BRATS_100patients.zip
+unzip BRATS_100patients.zip
+# download fastmri dataset
+wget https://huggingface.co/datasets/qicq1c/MRI_Reconstruction/resolve/main/singlecoil_train_selected.zip
+unzip singlecoil_train_selected.zip
+```
+
+## 2. Training
+##### BraTS dataset, AF=4
+```
+python train_brats.py --root_path /data/qic99/MRI_recon/image_100patients_4X/ \
+    --gpu 0 --batch_size 4 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \
+    --input_normalize mean_std \
+    --exp FSMNet_BraTS_4x
+```
+
+##### BraTS dataset, AF=8
+```
+python train_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \
+    --gpu 1 --batch_size 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \
+    --input_normalize mean_std \
+    --exp FSMNet_BraTS_8x
+```
+
+##### fastMRI dataset, AF=4
+```
+python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \
+    --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \
+    --exp FSMNet_fastmri_4x
+```
+
+##### fastMRI dataset, AF=8
+```
+python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \
+    --gpu 1 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \
+    --exp FSMNet_fastmri_8x
+```
+
+## 3. 
Testing +##### BraTS dataset, AF=4 +``` +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_4X/ \ + --gpu 3 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x --phase test +``` + +##### BraTS dataset, AF=8 +``` +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \ + --gpu 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x --phase test +``` + +##### fastMRI dataset, AF=4 +``` +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 5 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --phase test +``` + +##### fastMRI dataset, AF=8 +``` +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 6 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --phase test +``` \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/bash/brats.sh b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/bash/brats.sh new file mode 100644 index 0000000000000000000000000000000000000000..f0925f21f2c63b2dac30a07f7bbcd9f07e20abdc --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/bash/brats.sh @@ -0,0 +1,37 @@ +# BraTS dataset, AF=4 +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion/experiments/FSMNet + +#root_path=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee/image_100patients_4X/ +root_path=/gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_4X/ + +python train_brats.py --root_path $root_path\ + --gpu 0 --batch_size 4 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x + +BraTS dataset, AF=8 + +python train_brats.py --root_path /gamedrive/Datasets/medical/FrequencyDiffusion/brats/image_100patients_8X/ \ + --gpu 1 --batch_size 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x + + + + +# Test +BraTS dataset, AF=4 + +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_4X/ \ + --gpu 3 --base_lr 0.0001 --MRIDOWN 4X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_4x --phase test + +BraTS dataset, AF=8 + +python test_brats.py --root_path /data/qic99/MRI_recon/image_100patients_8X/ \ + --gpu 4 --base_lr 0.0001 --MRIDOWN 8X --low_field_SNR 0 \ + --input_normalize mean_std \ + --exp FSMNet_BraTS_8x --phase test + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/bash/fastmri.sh b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/bash/fastmri.sh new file mode 100644 index 0000000000000000000000000000000000000000..01c570029ae19591c104c360e0b5f2ae87292532 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/bash/fastmri.sh @@ -0,0 +1,32 @@ +# fastMRI dataset, AF=4 +# BraTS dataset, AF=4 +mamba activate diffmri +cd /home/cbtil3/hao/repo/Frequency-Diffusion/experiments/FSMNet + +data_root=/bask/projects/j/jiaoj-rep-learn/Hao/datasets/knee +#data_root=/gamedrive/Datasets/medical/FrequencyDiffusion + + +python train_fastmri.py --root_path $data_root \ + --gpu 0 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x + +# fastMRI dataset, AF=8 + +python train_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 1 --batch_size 4 
--base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x + + +# Test +fastMRI dataset, AF=4 + +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 5 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.08 --ACCELERATIONS 4 \ + --exp FSMNet_fastmri_4x --phase test + +fastMRI dataset, AF=8 + +python test_fastmri.py --root_path /data/qic99/MRI_recon/fastMRI/ \ + --gpu 6 --batch_size 4 --base_lr 0.0001 --CENTER_FRACTIONS 0.04 --ACCELERATIONS 8 \ + --exp FSMNet_fastmri_8x --phase test diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_DuDo_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_DuDo_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b06691ee683a347d4a20948d03598db65e9c08 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_DuDo_dataloader.py @@ -0,0 +1,295 @@ +""" +dual-domain network的dataloader, 读取两个模态的under-sampled和fully-sampled kspace data, 以及high-quality image作为监督信号。 +""" + + +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset + +from .kspace_subsample import undersample_mri, mri_fft + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. 
+ """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + + return data + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, HF_refine = 'False', split='train', MRIDOWN='4X', SNR=15, \ + transform=None, input_round = None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self.HF_refine = HF_refine + self.input_round = input_round + self._MRIDOWN = MRIDOWN + self._SNR = SNR + self.im_ids = [] + self.t2_images = [] + self.splits_path = "/data/xiaohan/BRATS_dataset/cv_splits_100patients/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + # self.test_file = self.splits_path + 'train_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + self.t2_images.append(t2_path) + + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + image_name = self.t1_images[index].split('t1')[0] + # print("image name:", image_name) + + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + # print("images:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + # print("t1 before standardization:", t1.max(), t1.min(), t1.mean()) + # print("loaded t1 range:", t1.max(), t1.min()) + # print("loaded t2 range:", t2.max(), t2 .min()) + + ### normalize the MRI image by divide_max + t1_max, t2_max = t1.max(), t2.max() + t1 = t1/t1_max + t2 = t2/t2_max + sample_stats = {"t1_max": t1_max, "t2_max": t2_max, "image_name": image_name} + + # sample_stats = {"t1_max": 1.0, "t2_max": 1.0} + + ### convert images to kspace and perform undersampling. 
+ t1_kspace_in, t1_in, t1_kspace, t1_img = mri_fft(t1, _SNR = self._SNR) + t2_kspace_in, t2_in, t2_kspace, t2_img, mask = undersample_mri( + t2, _MRIDOWN = self._MRIDOWN, _SNR = self._SNR) + + + # print("loaded t2 range:", t2.max(), t2.min()) + # print("t2_under_img range:", t2_under_img.max(), t2_under_img.min()) + # print("t2_kspace real_part range:", t2_kspace.real.max(), t2_kspace.real.min()) + # print("t2_kspace imaginary_part range:", t2_kspace.imag.max(), t2_kspace.imag.min()) + # print("t2_kspace_in real_part range:", t2_kspace_in.real.max(), t2_kspace_in.real.min()) + # print("t2_kspace_in imaginary_part range:", t2_kspace_in.imag.max(), t2_kspace_in.imag.min()) + + if self.HF_refine == "False": + sample = {'t1': t1_img, 't1_in': t1_in, 't1_kspace': t1_kspace, 't1_kspace_in': t1_kspace_in, \ + 't2': t2_img, 't2_in': t2_in, 't2_kspace': t2_kspace, 't2_kspace_in': t2_kspace_in, \ + 't2_mask': mask} + + elif self.HF_refine == "True": + ### 读取上一步重建的kspace data. + t1_krecon_path = self._base_dir + self.t1_images[index].replace( + 't1.png', 't1_' + str(self._SNR) + 'dB_recon_kspace_' + self.input_round + '_DudoLoss.npy') + t2_krecon_path = self._base_dir + self.t1_images[index].replace('t1.png', 't2_' + self._MRIDOWN + \ + '_' + str(self._SNR) + 'dB_recon_kspace_' + self.input_round + '_DudoLoss.npy') + + t1_krecon = np.load(t1_krecon_path) + t2_krecon = np.load(t2_krecon_path) + # print("t1 and t2 recon kspace:", t1_krecon.shape, t2_krecon.shape) + # + sample = {'t1': t1_img, 't1_in': t1_in, 't1_kspace': t1_kspace, 't1_kspace_in': t1_kspace_in, \ + 't2': t2_img, 't2_in': t2_in, 't2_kspace': t2_kspace, 't2_kspace_in': t2_kspace_in, \ + 't2_mask': mask, 't1_krecon': t1_krecon, 't2_krecon': t2_krecon} + + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img = torch.from_numpy(img).float() + target = torch.from_numpy(target).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'ct': img, 'mri': target} + + +# class ToTensor(object): +# """Convert ndarrays in sample to Tensors.""" + +# def __call__(self, sample): +# # swap color axis because +# # numpy image: H x W x C +# # torch image: C X H X W +# img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) +# img = sample['image'][:, :, None].transpose((2, 0, 1)) +# target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) +# target = sample['target'][:, :, None].transpose((2, 0, 1)) +# # print("img_in before_numpy range:", img_in.max(), img_in.min()) +# img_in = torch.from_numpy(img_in).float() +# img = torch.from_numpy(img).float() +# target_in = torch.from_numpy(target_in).float() +# target = torch.from_numpy(target).float() +# # print("img_in range:", img_in.max(), img_in.min()) + +# return {'ct_in': img_in, +# 'ct': img, +# 'mri_in': target_in, +# 'mri': target} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..52523fb1e080166812a64c191f92884cee244219 --- /dev/null 
+++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_dataloader.py @@ -0,0 +1,175 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import os +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', SNR=15, transform=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.splits_path = "/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/cv_splits/" + + + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + else: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + if MRIDOWN == "False": + t2_under_path = image_path.replace('t1', 't2_' + str(SNR) + 'dB') + else: + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + + # print("image paths:", image_path, t1_under_path, t2_path, t2_under_path) + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + # print("t1 images:", self.t1_images) + # print("t2 images:", self.t2_images) + # print("t1_undermri_images:", self.t1_undermri_images) + # print("t2_undermri_images:", self.t2_undermri_images) + + self.transform = transform + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + ### 两种settings. + ### 1. T1 fully-sampled 不加noise, T2 down-sampled, 做MRI acceleration. + ### 2. T1 fully-sampled 但是加noise, T2 down-sampled同时也加noise, 同时做MRI acceleration and enhancement. + ### T1, T2两个模态的输入都是low-quality images. 
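+        # (English gloss of the notes above) Two settings are supported:
+        # 1. T1 fully sampled without noise, T2 down-sampled -> MRI acceleration only.
+        # 2. T1 fully sampled but noisy, T2 down-sampled and noisy -> joint acceleration and enhancement.
+        # In both settings the T1 and T2 network inputs are low-quality images.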
+ sample = {'image_in': np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0, + 'image': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + 'target_in': np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0, + 'target': np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0} + + + # ### 2023/05/23, Xiaohan, 把T1模态的输入改成high-quality图像(和ground truth一致,看能否为T2提供更好的guidance)。 + # sample = {'image_in': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + # 'image': np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0, + # 'target_in': np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0, + # 'target': np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0} + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_dataloader_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_dataloader_new.py new file mode 100644 index 0000000000000000000000000000000000000000..288e448bd06ffd5fd94e253e742dd29ed253ab34 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_dataloader_new.py @@ -0,0 +1,371 @@ +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset +from torchvision import transforms + + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', \ + SNR=15, transform=None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.t1_krecon_images, self.t2_krecon_images = [], [] + self.kspace_refine = "False" # ADD + + + name = base_dir.rstrip("/ ").split('/')[-1] + print("base_dir=", base_dir, ", folder name =", name) + self.splits_path = base_dir.replace(name, 'cv_splits_100patients/') + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + + if SNR == 0: + t1_under_path = image_path + + if self.kspace_refine == "False": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + elif self.kspace_refine == "True": + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_krecon') + + if self.kspace_refine == "False": + t1_krecon_path = image_path + t2_krecon_path = image_path + + # if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + + else: + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + 
str(SNR) + 'dB') + + t1_krecon_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_krecon_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB') + + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + self.t1_krecon_images.append(t1_krecon_path) + self.t2_krecon_images.append(t2_krecon_path) + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + t1_krecon = np.array(Image.open(self._base_dir + self.t1_krecon_images[index]))/255.0 + + t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + t2_krecon = np.array(Image.open(self._base_dir + self.t2_krecon_images[index]))/255.0 + + if self.input_normalize == "mean_std": + t1_in, t1_mean, t1_std = normalize_instance(t1_in, eps=1e-11) + t1 = normalize(t1, t1_mean, t1_std, eps=1e-11) + t2_in, t2_mean, t2_std = normalize_instance(t2_in, eps=1e-11) + t2 = normalize(t2, t2_mean, t2_std, eps=1e-11) + + t1_krecon = normalize(t1_krecon, t1_mean, t1_std, eps=1e-11) + t2_krecon = normalize(t2_krecon, t2_mean, t2_std, eps=1e-11) + + ### clamp input to ensure training stability. 
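+            # after (x - mean) / std instance normalization the pixel values are in units of standard deviations,
+            # so clipping to [-6, 6] only trims extreme outliers (more than 6 std from the mean).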
+ t1_in = np.clip(t1_in, -6, 6) + t1 = np.clip(t1, -6, 6) + t2_in = np.clip(t2_in, -6, 6) + t2 = np.clip(t2, -6, 6) + + t1_krecon = np.clip(t1_krecon, -6, 6) + t2_krecon = np.clip(t2_krecon, -6, 6) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + t1_in = (t1_in - t1_in.min())/(t1_in.max() - t1_in.min()) + t1 = (t1 - t1.min())/(t1.max() - t1.min()) + t2_in = (t2_in - t2_in.min())/(t2_in.max() - t2_in.min()) + t2 = (t2 - t2.min())/(t2.max() - t2.min()) + sample_stats = 0 + + elif self.input_normalize == "divide": + sample_stats = 0 + + + sample = {'image_in': t1_in, + 'image': t1, + 'image_krecon': t1_krecon, + 'target_in': t2_in, + 'target': t2, + 'target_krecon': t2_krecon} + + # print("images shape:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + + +def add_gaussian_noise(img, mean=0, std=1): + noise = std * torch.randn_like(img) + mean + noisy_img = img + noise + return torch.clamp(noisy_img, 0, 1) + + + +class AddNoise(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + add_gauss_noise = transforms.GaussianBlur(kernel_size=5) + add_poiss_noise = transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)) + + add_noise = transforms.RandomApply([add_gauss_noise, add_poiss_noise], p=0.5) + + img_in = add_noise(img_in) + target_in = add_noise(target_in) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + + return sample + + + + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 256, 256 + crop_size = 240 + pad_size = (256-240)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_krecon = sample['image_krecon'] + target_krecon = sample['target_krecon'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + img_krecon = np.pad(img_krecon, pad_size, mode='reflect') + target_krecon = np.pad(target_krecon, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + # print("img_in:", img_in.shape) + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + img_krecon = img_krecon[ww:ww+crop_size, hh:hh+crop_size] + target_krecon = target_krecon[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'image_krecon': img_krecon, \ + 'target_in': target_in, 'target': target, 'target_krecon': target_krecon} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = 
random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + +class RandomFlip(object): + def __call__(self, sample): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + # horizontal flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 1) + img = cv2.flip(img, 1) + target_in = cv2.flip(target_in, 1) + target = cv2.flip(target, 1) + + # vertical flip + if random.random() < 0.5: + img_in = cv2.flip(img_in, 0) + img = cv2.flip(img, 0) + target_in = cv2.flip(target_in, 0) + target = cv2.flip(target, 0) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + + + +class RandomRotate(object): + def __call__(self, sample, center=None, scale=1.0): + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + degrees = [0, 90, 180, 270] + angle = random.choice(degrees) + + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + + img_in = cv2.warpAffine(img_in, matrix, (w, h)) + img = cv2.warpAffine(img, matrix, (w, h)) + target_in = cv2.warpAffine(target_in, matrix, (w, h)) + target = cv2.warpAffine(target, matrix, (w, h)) + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + + image_krecon = sample['image_krecon'][:, :, None].transpose((2, 0, 1)) + target_krecon = sample['target_krecon'][:, :, None].transpose((2, 0, 1)) + + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + image_krecon = torch.from_numpy(image_krecon).float() + target_krecon = torch.from_numpy(target_krecon).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'image_in': img_in, + 'image': img, + 'target_in': target_in, + 'target': target, + 'image_krecon': image_krecon, + 'target_krecon': target_krecon} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_kspace_dataloader.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_kspace_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..871a153b20eac89e45ec0025e2aa31476360fde0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/BRATS_kspace_dataloader.py @@ -0,0 +1,298 @@ +""" +Load the low-quality and high-quality images from the BRATS dataset and transform to kspace. 
+""" + + +from __future__ import print_function, division +import numpy as np +import pandas as pd +from glob import glob +import random +from skimage import transform +from PIL import Image + +import cv2 +import os +import torch +from torch.utils.data import Dataset + +from .kspace_subsample import undersample_mri, mri_fft + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + Returns: + The IFFT of the input. 
+ """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + + return data + + + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', MRIDOWN='4X', SNR=15, transform=None, input_normalize=None): + + super().__init__() + self._base_dir = base_dir + self._MRIDOWN = MRIDOWN + self.im_ids = [] + self.t2_images = [] + self.t1_undermri_images, self.t2_undermri_images = [], [] + self.splits_path = "/data/xiaohan/BRATS_dataset/cv_splits_100patients/" + + if split=='train': + self.train_file = self.splits_path + 'train_data.csv' + train_images = pd.read_csv(self.train_file).iloc[:, -1].values.tolist() + self.t1_images = [image for image in train_images if image.split('_')[-1]=='t1.png'] + + elif split=='test': + self.test_file = self.splits_path + 'test_data.csv' + # self.test_file = self.splits_path + 'train_data.csv' + test_images = pd.read_csv(self.test_file).iloc[:, -1].values.tolist() + # test_images = os.listdir(self._base_dir) + self.t1_images = [image for image in test_images if image.split('_')[-1]=='t1.png'] + + + for image_path in self.t1_images: + t2_path = image_path.replace('t1', 't2') + if SNR == 0: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_undermri') + t1_under_path = image_path + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_undermri') + else: + # t1_under_path = image_path.replace('t1', 't1_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + t1_under_path = image_path.replace('t1', 't1_' + str(SNR) + 'dB') + t2_under_path = image_path.replace('t1', 't2_' + self._MRIDOWN + '_' + str(SNR) + 'dB_undermri') + + self.t2_images.append(t2_path) + self.t1_undermri_images.append(t1_under_path) + self.t2_undermri_images.append(t2_under_path) + + # print("t1 images:", self.t1_images) + # print("t2 images:", self.t2_images) + # print("t1_undermri_images:", self.t1_undermri_images) + # print("t2_undermri_images:", self.t2_undermri_images) + + self.transform = transform + self.input_normalize = input_normalize + + assert (len(self.t1_images) == len(self.t2_images)) + assert (len(self.t1_images) == len(self.t1_undermri_images)) + assert (len(self.t1_images) == len(self.t2_undermri_images)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.t1_images))) + + def __len__(self): + return len(self.t1_images) + + + def __getitem__(self, index): + + # t1_in = np.array(Image.open(self._base_dir + self.t1_undermri_images[index]))/255.0 + t1 = np.array(Image.open(self._base_dir + self.t1_images[index]))/255.0 + # t2_in = np.array(Image.open(self._base_dir + self.t2_undermri_images[index]))/255.0 + t2 = np.array(Image.open(self._base_dir + self.t2_images[index]))/255.0 + # print("images:", t1_in.shape, t1.shape, t2_in.shape, t2.shape) + # print("t1 before standardization:", t1.max(), t1.min(), t1.mean()) + # print("t1 range:", t1.max(), t1.min()) + # print("t2 range:", t2.max(), t2 .min()) + + if self.input_normalize == "mean_std": + ### 对input image和target image都做(x-mean)/std的归一化操作 + t1, t1_mean, t1_std = normalize_instance(t1, eps=1e-11) + t2, t2_mean, t2_std = normalize_instance(t2, eps=1e-11) + + ### clamp input to ensure training stability. 
+ t1 = np.clip(t1, -6, 6) + t2 = np.clip(t2, -6, 6) + # print("t1 after standardization:", t1.max(), t1.min(), t1.mean()) + + sample_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + elif self.input_normalize == "min_max": + # t1 = (t1 - t1.min())/(t1.max() - t1.min()) + # t2 = (t2 - t2.min())/(t2.max() - t2.min()) + t1 = t1/t1.max() + t2 = t2/t2.max() + sample_stats = 0 + + elif self.input_normalize == "divide": + sample_stats = 0 + + + ### convert images to kspace and perform undersampling. + # t1_kspace, t1_masked_kspace, t1_img, t1_under_img = undersample_mri(t1, _MRIDOWN = None) + t1_kspace, t1_img = mri_fft(t1) + t2_kspace, t2_masked_kspace, t2_img, t2_under_img, mask = undersample_mri(t2, _MRIDOWN = self._MRIDOWN) + + + sample = {'t1': t1_img, 't2': t2_img, 'under_t2': t2_under_img, "t2_mask": mask, \ + 't1_kspace': t1_kspace, 't2_kspace': t2_kspace, 't2_masked_kspace': t2_masked_kspace} + + if self.transform is not None: + sample = self.transform(sample) + + return sample, sample_stats + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, :, None].transpose((2, 0, 1)) + # print("img_in before_numpy range:", img_in.max(), img_in.min()) + img = torch.from_numpy(img).float() + target = torch.from_numpy(target).float() + # print("img_in range:", img_in.max(), img_in.min()) + + return {'ct': img, 'mri': target} + + +# class ToTensor(object): +# """Convert ndarrays in sample to Tensors.""" + +# def __call__(self, sample): +# # swap color axis because +# # numpy image: H x W x C +# # torch image: C X H X W +# img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) +# img = sample['image'][:, :, None].transpose((2, 0, 1)) +# target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) +# target = sample['target'][:, :, None].transpose((2, 0, 1)) +# # print("img_in before_numpy range:", img_in.max(), img_in.min()) +# img_in = torch.from_numpy(img_in).float() +# img = torch.from_numpy(img).float() +# target_in = torch.from_numpy(target_in).float() +# target = torch.from_numpy(target).float() +# # print("img_in range:", img_in.max(), img_in.min()) + +# return {'ct_in': img_in, +# 'ct': img, +# 'mri_in': target_in, +# 'mri': target} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/fastmri.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..0502ac9ccb96df4f55908cd92d5db432239659fe --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/fastmri.py @@ -0,0 +1,222 @@ +import csv +import os +import random +import xml.etree.ElementTree as etree +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +import pathlib + +import h5py +import numpy as np +import torch +import yaml +from torch.utils.data import Dataset +from .transforms import build_transforms +from matplotlib import pyplot as plt + + +def fetch_dir(key, 
data_config_file=pathlib.Path("fastmri_dirs.yaml")): + """ + Data directory fetcher. + + This is a brute-force simple way to configure data directories for a + project. Simply overwrite the variables for `knee_path` and `brain_path` + and this function will retrieve the requested subsplit of the data for use. + + Args: + key (str): key to retrieve path from data_config_file. + data_config_file (pathlib.Path, + default=pathlib.Path("fastmri_dirs.yaml")): Default path config + file. + + Returns: + pathlib.Path: The path to the specified directory. + """ + if not data_config_file.is_file(): + default_config = dict( + knee_path="/home/jc3/Data/", + brain_path="/home/jc3/Data/", + ) + with open(data_config_file, "w") as f: + yaml.dump(default_config, f) + + raise ValueError(f"Please populate {data_config_file} with directory paths.") + + with open(data_config_file, "r") as f: + data_dir = yaml.safe_load(f)[key] + + data_dir = pathlib.Path(data_dir) + + if not data_dir.exists(): + raise ValueError(f"Path {data_dir} from {data_config_file} does not exist.") + + return data_dir + + +def et_query( + root: etree.Element, + qlist: Sequence[str], + namespace: str = "http://www.ismrm.org/ISMRMRD", +) -> str: + """ + ElementTree query function. + This can be used to query an xml document via ElementTree. It uses qlist + for nested queries. + Args: + root: Root of the xml to search through. + qlist: A list of strings for nested searches, e.g. ["Encoding", + "matrixSize"] + namespace: Optional; xml namespace to prepend query. + Returns: + The retrieved data as a string. + """ + s = "." + prefix = "ismrmrd_namespace" + + ns = {prefix: namespace} + + for el in qlist: + s = s + f"//{prefix}:{el}" + + value = root.find(s, ns) + if value is None: + raise RuntimeError("Element not found") + + return str(value.text) + + +class SliceDataset(Dataset): + def __init__( + self, + root, + transform, + challenge, + sample_rate=1, + mode='train' + ): + self.mode = mode + + # challenge + if challenge not in ("singlecoil", "multicoil"): + raise ValueError('challenge should be either "singlecoil" or "multicoil"') + self.recons_key = ( + "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss" + ) + # transform + self.transform = transform + + self.examples = [] + + self.cur_path = root + if not os.path.exists(self.cur_path): + self.cur_path = self.cur_path + "_selected" + + self.csv_file = "knee_data_split/singlecoil_" + self.mode + "_split_less.csv" + + with open(self.csv_file, 'r') as f: + reader = csv.reader(f) + + id = 0 + + for row in reader: + pd_metadata, pd_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[0] + '.h5')) + + pdfs_metadata, pdfs_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[1] + '.h5')) + + for slice_id in range(min(pd_num_slices, pdfs_num_slices)): + self.examples.append( + (os.path.join(self.cur_path, row[0] + '.h5'), os.path.join(self.cur_path, row[1] + '.h5') + , slice_id, pd_metadata, pdfs_metadata, id)) + id += 1 + + if sample_rate < 1: + random.shuffle(self.examples) + num_examples = round(len(self.examples) * sample_rate) + + self.examples = self.examples[0:num_examples] + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + + # read pd + pd_fname, pdfs_fname, slice, pd_metadata, pdfs_metadata, id = self.examples[i] + + with h5py.File(pd_fname, "r") as hf: + pd_kspace = hf["kspace"][slice] + + pd_mask = np.asarray(hf["mask"]) if "mask" in hf else None + + pd_target = hf[self.recons_key][slice] if 
self.recons_key in hf else None + + attrs = dict(hf.attrs) + + attrs.update(pd_metadata) + + if self.transform is None: + pd_sample = (pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + else: + pd_sample = self.transform(pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) + + with h5py.File(pdfs_fname, "r") as hf: + pdfs_kspace = hf["kspace"][slice] + pdfs_mask = np.asarray(hf["mask"]) if "mask" in hf else None + + pdfs_target = hf[self.recons_key][slice] if self.recons_key in hf else None + + attrs = dict(hf.attrs) + + attrs.update(pdfs_metadata) + + if self.transform is None: + pdfs_sample = (pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + else: + pdfs_sample = self.transform(pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) + + + # dataset pdf mean and std tensor(3.1980e-05) tensor(1.3093e-05) + # print("dataset pdf mean and std", pdfs_sample[2], pdfs_sample[3]) + # print(pdfs_sample[1].shape, pdfs_sample[1].min(), pdfs_sample[1].max()) + + return (pd_sample, pdfs_sample, id) + + def _retrieve_metadata(self, fname): + with h5py.File(fname, "r") as hf: + et_root = etree.fromstring(hf["ismrmrd_header"][()]) + + enc = ["encoding", "encodedSpace", "matrixSize"] + enc_size = ( + int(et_query(et_root, enc + ["x"])), + int(et_query(et_root, enc + ["y"])), + int(et_query(et_root, enc + ["z"])), + ) + rec = ["encoding", "reconSpace", "matrixSize"] + recon_size = ( + int(et_query(et_root, rec + ["x"])), + int(et_query(et_root, rec + ["y"])), + int(et_query(et_root, rec + ["z"])), + ) + + lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"] + enc_limits_center = int(et_query(et_root, lims + ["center"])) + enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1 + + padding_left = enc_size[1] // 2 - enc_limits_center + padding_right = padding_left + enc_limits_max + + num_slices = hf["kspace"].shape[0] + + metadata = { + "padding_left": padding_left, + "padding_right": padding_right, + "encoding_size": enc_size, + "recon_size": recon_size, + } + + return metadata, num_slices + + +def build_dataset(args, mode='train', sample_rate=1): + assert mode in ['train', 'val', 'test'], 'unknown mode' + transforms = build_transforms(args, mode) + return SliceDataset(os.path.join(args.root_path, 'singlecoil_' + mode), transforms, 'singlecoil', sample_rate=sample_rate, mode=mode) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/hybrid_sparse.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/hybrid_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..42e4a7e33c2204c13a1c4509897baf19e1fb07f1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/hybrid_sparse.py @@ -0,0 +1,156 @@ +from __future__ import print_function, division +import numpy as np +from glob import glob +import random +from skimage import transform + +import torch +from torch.utils.data import Dataset + +class Hybrid(Dataset): + + def __init__(self, base_dir=None, split='train', transform=None): + + super().__init__() + self._base_dir = base_dir + self.im_ids = [] + self.images = [] + self.gts = [] + + if split=='train': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir+"/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + + elif split=='test': + self._image_dir = self._base_dir + imagelist = glob(self._image_dir + 
"/*_ct.png") + imagelist=sorted(imagelist) + for image_path in imagelist: + gt_path = image_path.replace('ct', 't1') + self.images.append(image_path) + self.gts.append(gt_path) + + self.transform = transform + + assert (len(self.images) == len(self.gts)) + + # Display stats + print('Number of images in {}: {:d}'.format(split, len(self.images))) + + def __len__(self): + return len(self.images) + + + def __getitem__(self, index): + img_in, img, target_in, target= self._make_img_gt_point_pair(index) + sample = {'image_in': img_in, 'image':img, 'target_in': target_in, 'target': target} + # print("image in:", img_in.shape) + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + + def _make_img_gt_point_pair(self, index): + # Read Image and Target + + # the default setting (i.e., rawdata.npz) is 4X64P + dd = np.load(self.images[index].replace('.png', '_raw_4X64P.npz')) + # print("images range:", dd['fbp'].max(), dd['ct'].max(), dd['under_t1'].max(), dd['t1'].max()) + _img_in = dd['fbp'] + _img_in[_img_in>0.6]=0.6 + _img_in = _img_in/0.6 + + _img = dd['ct'] + _img =(_img/1000*0.192+0.192) + _img[_img<0.0]=0.0 + _img[_img>0.6]=0.6 + _img = _img/0.6 + + _target_in = dd['under_t1'] + _target = dd['t1'] + + return _img_in, _img, _target_in, _target + +class RandomPadCrop(object): + def __call__(self, sample): + new_w, new_h = 400, 400 + crop_size = 384 + pad_size = (400-384)//2 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = np.pad(img_in, pad_size, mode='reflect') + img = np.pad(img, pad_size, mode='reflect') + target_in = np.pad(target_in, pad_size, mode='reflect') + target = np.pad(target, pad_size, mode='reflect') + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class RandomResizeCrop(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + new_w, new_h = 270, 270 + crop_size = 256 + img_in = sample['image_in'] + img = sample['image'] + target_in = sample['target_in'] + target = sample['target'] + + img_in = transform.resize(img_in, (new_h, new_w), order=3) + img = transform.resize(img, (new_h, new_w), order=3) + target_in = transform.resize(target_in, (new_h, new_w), order=3) + target = transform.resize(target, (new_h, new_w), order=3) + + ww = random.randint(0, np.maximum(0, new_w - crop_size)) + hh = random.randint(0, np.maximum(0, new_h - crop_size)) + + img_in = img_in[ww:ww+crop_size, hh:hh+crop_size] + img = img[ww:ww+crop_size, hh:hh+crop_size] + target_in = target_in[ww:ww+crop_size, hh:hh+crop_size] + target = target[ww:ww+crop_size, hh:hh+crop_size] + + sample = {'image_in': img_in, 'image': img, 'target_in': target_in, 'target': target} + return sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img_in = sample['image_in'][:, :, None].transpose((2, 0, 1)) + img = sample['image'][:, :, None].transpose((2, 0, 1)) + target_in = sample['target_in'][:, :, None].transpose((2, 0, 1)) + target = sample['target'][:, 
:, None].transpose((2, 0, 1)) + img_in = torch.from_numpy(img_in).float() + img = torch.from_numpy(img).float() + target_in = torch.from_numpy(target_in).float() + target = torch.from_numpy(target).float() + + return {'ct_in': img_in, + 'ct': img, + 'mri_in': target_in, + 'mri': target} diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/kspace_subsample.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/kspace_subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5b5694d8fee8b35ba8394fae98fe2d3aa25759 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/kspace_subsample.py @@ -0,0 +1,287 @@ +""" +2023/10/16, +preprocess kspace data with the undersampling mask in the fastMRI project. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +## mri related +def mri_fourier_transform_2d(image, mask): + ''' + image: input tensor [B, H, W, C] + mask: mask tensor [H, W] + ''' + spectrum = torch.fft.fftn(image, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + spectrum = torch.fft.fftshift(spectrum, dim=(1, 2)) + # Downsample k-space + masked_spectrum = spectrum * mask[None, :, :, None] + return spectrum, masked_spectrum + + +## mri related +def mri_inver_fourier_transform_2d(spectrum): + ''' + image: input tensor [B, H, W, C] + ''' + spectrum = torch.fft.ifftshift(spectrum, dim=(1, 2)) + image = torch.fft.ifftn(spectrum, dim=(1, 2), norm='ortho') + + return image + + +def add_gaussian_noise(kspace, snr): + ### 根据SNR确定noise的放大比例 + num_pixels = kspace.shape[0]*kspace.shape[1]*kspace.shape[2]*kspace.shape[3] + psr = torch.sum(torch.abs(kspace.real)**2)/num_pixels + pnr = psr/(np.power(10, snr/10)) + noise_r = torch.randn_like(kspace.real)*np.sqrt(pnr) + + psim = torch.sum(torch.abs(kspace.imag)**2)/num_pixels + pnim = psim/(np.power(10, snr/10)) + noise_im = torch.randn_like(kspace.imag)*np.sqrt(pnim) + + noise = noise_r + 1j*noise_im + noisy_kspace = kspace + noise + + return noisy_kspace + + +def mri_fft(raw_mri, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + spectrum = torch.fft.fftn(mri, dim=(1, 2), norm='ortho') + # K-space spectrum has been shifted to shift the zero-frequency component to the center of the spectrum + kspace = torch.fft.fftshift(spectrum, dim=(1, 2)) + + if _SNR > 0: + noisy_kspace = add_gaussian_noise(kspace, _SNR) + else: + noisy_kspace = kspace + + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1) + + + +def undersample_mri(raw_mri, _MRIDOWN, _SNR): + mri = torch.tensor(raw_mri)[None, :, :, None].to(torch.float32) + if _MRIDOWN == "4X": + mask_type_str, center_fraction, MRIDOWN = "random", 0.1, 4 + elif _MRIDOWN == "8X": + 
mask_type_str, center_fraction, MRIDOWN = "equispaced", 0.04, 8 + + ff = create_mask_for_mask_type(mask_type_str, [center_fraction], [MRIDOWN]) ## 0.2 for MRIDOWN=2, 0.1 for MRIDOWN=4, 0.04 for MRIDOWN=8 + + shape = [240, 240, 1] + mask = ff(shape, seed=1337) + mask = mask[:, :, 0] # [1, 240] + # print("mask:", mask.shape) + # print("original MRI:", mri) + + # print("original MRI:", mri.shape) + ### under-sample the kspace data. + kspace, masked_kspace = mri_fourier_transform_2d(mri, mask) + ### add low-field noise to the kspace data. + if _SNR > 0: + noisy_kspace = add_gaussian_noise(masked_kspace, _SNR) + else: + noisy_kspace = masked_kspace + + ### conver the corrupted kspace data back to noisy MRI image. + noisy_mri = mri_inver_fourier_transform_2d(noisy_kspace) + noisy_mri = torch.sqrt(torch.real(noisy_mri)**2 + torch.imag(noisy_mri)**2) + + return noisy_kspace[0].permute(2, 0, 1), noisy_mri[0].permute(2, 0, 1), \ + kspace[0].permute(2, 0, 1), mri[0].permute(2, 0, 1), mask.unsqueeze(-1) + + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. 
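As an illustrative aside (not part of the diff), the column-sampling probability described in the RandomMaskFunc docstring can be sanity-checked numerically: once the centre columns are forced on, the expected number of retained columns works out to exactly N / acceleration. The values below are assumed example numbers, not taken from the file.

```python
# Illustrative sanity check of the RandomMaskFunc sampling probability
# (assumed example values; 368 columns is a typical fastMRI knee k-space width).
N, acceleration, center_fraction = 368, 4, 0.08

n_low = round(N * center_fraction)                   # centre columns always kept
prob = (N / acceleration - n_low) / (N - n_low)      # keep-probability for the rest
expected = n_low + prob * (N - n_low)                # expected retained columns

print(n_low, round(prob, 4), round(expected, 6), N / acceleration)  # 29 0.1858 92.0 92.0
```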
+ seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. 
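Putting the helpers in this file together, here is a minimal usage sketch of `undersample_mri` (a hypothetical 240x240 magnitude slice, 4X undersampling, `_SNR=0` so no noise is added). It assumes the functions defined above are in scope, e.g. when run at the bottom of this module; the printed shapes follow directly from the code above.

```python
import numpy as np

# Hypothetical input slice; 240x240 matches the mask shape hard-coded above.
slice_2d = np.random.rand(240, 240).astype(np.float32)

noisy_kspace, noisy_mri, kspace, mri, mask = undersample_mri(slice_2d, _MRIDOWN="4X", _SNR=0)

print(noisy_kspace.shape)  # torch.Size([1, 240, 240]), complex undersampled k-space
print(noisy_mri.shape)     # torch.Size([1, 240, 240]), zero-filled magnitude image
print(mask.shape)          # torch.Size([1, 240, 1]), column mask broadcast over rows
```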
+ """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/math.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/math.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0d76246b0c8fb757b6b589b7f889e0316696d0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/math.py @@ -0,0 +1,272 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import numpy as np + +def complex_mul(x, y): + """ + Complex multiplication. + + This multiplies two complex tensors assuming that they are both stored as + real arrays with the last dimension being the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == y.shape[-1] == 2 + re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] + im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] + + return torch.stack((re, im), dim=-1) + + +def complex_conj(x): + """ + Complex conjugate. + + This applies the complex conjugate assuming that the input array has the + last dimension as the complex dimension. + + Args: + x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. + + Returns: + torch.Tensor: A PyTorch tensor with the last dimension of size 2. + """ + assert x.shape[-1] == 2 + + return torch.stack((x[..., 0], -x[..., 1]), dim=-1) + + +# def fft2c(data): +# """ +# Apply centered 2 dimensional Fast Fourier Transform. + +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The FFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# data = torch.fft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + +# def ifft2c(data): +# """ +# Apply centered 2-dimensional Inverse Fast Fourier Transform. 
+ +# Args: +# data (torch.Tensor): Complex valued input data containing at least 3 +# dimensions: dimensions -3 & -2 are spatial dimensions and dimension +# -1 has size 2. All other dimensions are assumed to be batch +# dimensions. + +# Returns: +# torch.Tensor: The IFFT of the input. +# """ +# assert data.size(-1) == 2 +# data = ifftshift(data, dim=(-3, -2)) +# # data = torch.ifft(data, 2, normalized=True) +# data = torch.fft.ifft(data, 2, normalized=True) +# data = fftshift(data, dim=(-3, -2)) + +# return data + + + +def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.fft``. + + Returns: + The FFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.fftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + +def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: + """ + Apply centered 2-dimensional Inverse Fast Fourier Transform. + + Args: + data: Complex valued input data containing at least 3 dimensions: + dimensions -3 & -2 are spatial dimensions and dimension -1 has size + 2. All other dimensions are assumed to be batch dimensions. + norm: Normalization mode. See ``torch.fft.ifft``. + + Returns: + The IFFT of the input. + """ + if not data.shape[-1] == 2: + raise ValueError("Tensor does not have separate complex dim.") + + data = ifftshift(data, dim=[-3, -2]) + data = torch.view_as_real( + torch.fft.ifftn( # type: ignore + torch.view_as_complex(data), dim=(-2, -1), norm=norm + ) + ) + data = fftshift(data, dim=[-3, -2]) + + return data + + + + + +def complex_abs(data): + """ + Compute the absolute value of a complex valued input tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Absolute value of data. + """ + assert data.size(-1) == 2 + + return (data ** 2).sum(dim=-1).sqrt() + + + +def complex_abs_numpy(data): + assert data.shape[-1] == 2 + + return np.sqrt(np.sum(data ** 2, axis=-1)) + + +def complex_abs_sq(data):#multi coil + """ + Compute the squared absolute value of a complex tensor. + + Args: + data (torch.Tensor): A complex valued tensor, where the size of the + final dimension should be 2. + + Returns: + torch.Tensor: Squared absolute value of data. + """ + assert data.size(-1) == 2 + return (data ** 2).sum(dim=-1) + + +# Helper functions + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors. + + Args: + x (torch.Tensor): A PyTorch tensor. + shift (int): Amount to roll. + dim (int): Which dimension to roll. + + Returns: + torch.Tensor: Rolled version of x. 
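Two quick consistency checks (illustrative only, assuming the functions above are in scope): `complex_mul` should agree with native complex multiplication, and `fft2c`/`ifft2c` should be exact inverses under the "ortho" normalization.

```python
import torch

x = torch.randn(1, 320, 320, 2)   # (..., 2) real/imag convention used in this module
y = torch.randn(1, 320, 320, 2)

# complex_mul vs. native complex multiplication via view_as_complex.
ref = torch.view_as_complex(x) * torch.view_as_complex(y)
print(torch.allclose(complex_mul(x, y), torch.view_as_real(ref), atol=1e-6))  # True

# Centered FFT round trip.
print(torch.allclose(ifft2c(fft2c(x)), x, atol=1e-5))  # True
```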
+ """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to fftshift. + + Returns: + torch.Tensor: fftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x (torch.Tensor): A PyTorch tensor. + dim (int): Which dimension to ifftshift. + + Returns: + torch.Tensor: ifftshifted version of x. + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + + return roll(x, shift, dim) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data + """ + data = data.numpy() + return data[..., 0] + 1j * data[..., 1] diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/subsample.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..d0620da3414c6077e4293376fb8a9be01ad19990 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/subsample.py @@ -0,0 +1,195 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib + +import numpy as np +import torch + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class MaskFunc(object): + """ + An object for GRAPPA-style sampling masks. + + This crates a sampling mask that densely samples the center while + subsampling outer k-space regions based on the undersampling factor. + """ + + def __init__(self, center_fractions, accelerations): + """ + Args: + center_fractions (List[float]): Fraction of low-frequency columns to be + retained. If multiple values are provided, then one of these + numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have + the same length as center_fractions. If multiple values are + provided, then one of these is chosen uniformly each time. 
+ """ + if len(center_fractions) != len(accelerations): + raise ValueError( + "Number of center fractions should match number of accelerations" + ) + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random + + def choose_acceleration(self): + """Choose acceleration based on class parameters.""" + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + + return center_fraction, acceleration + + +class RandomMaskFunc(MaskFunc): + """ + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding to low-frequencies. + 2. The other columns are selected uniformly at random with a + probability equal to: prob = (N / acceleration - N_low_freqs) / + (N - N_low_freqs). This ensures that the expected number of columns + selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the RandomMaskFunc object is called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], + then there is a 50% probability that 4-fold acceleration with 8% center + fraction is selected and a 50% probability that 8-fold acceleration with 4% + center fraction is selected. + """ + + def __call__(self, shape, seed=None): + """ + Create the mask. + + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the + k-space data has N columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center + corresponding tovlow-frequencies. + 2. The other columns are selected with equal spacing at a proportion + that reaches the desired acceleration rate taking into consideration + the number of low frequencies. 
This ensures that the expected number + of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which + case one possible (center_fraction, acceleration) is chosen uniformly at + random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require + modifications to standard GRAPPA approaches. Nonetheless, this aspect of + the function has been preserved to match the public multicoil data. + """ + + def __call__(self, shape, seed): + """ + Args: + shape (iterable[int]): The shape of the mask to be created. The + shape should have at least 3 dimensions. Samples are drawn + along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting + the seed ensures the same mask is generated each time for the + same shape. The random state is reset afterwards. + + Returns: + torch.Tensor: A mask of the specified shape. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + num_cols = shape[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( + num_low_freqs * acceleration - num_cols + ) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/transforms.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..536eecc5bef52a969001f5f68fc91a38fdc549ba --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/dataloaders/transforms.py @@ -0,0 +1,485 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch + +from .math import ifft2c, fft2c, complex_abs +from .subsample import create_mask_for_mask_type, MaskFunc +import random + +from typing import Dict, Optional, Sequence, Tuple, Union +from matplotlib import pyplot as plt +import os + +def rss(data, dim=0): + """ + Compute the Root Sum of Squares (RSS). + + RSS is computed assuming that dim is the coil dimension. + + Args: + data (torch.Tensor): The input tensor + dim (int): The dimensions along which to apply the RSS transform + + Returns: + torch.Tensor: The RSS value. + """ + return torch.sqrt((data ** 2).sum(dim)) + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. + + For complex arrays, the real and imaginary parts are stacked along the last + dimension. + + Args: + data (np.array): Input numpy array. 
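For multicoil data, `rss` above combines per-coil magnitude images along the coil dimension. A tiny illustrative call (coil count and image size assumed):

```python
import torch

coil_images = torch.randn(15, 320, 320).abs()  # 15 hypothetical coil magnitude images
combined = rss(coil_images, dim=0)             # root-sum-of-squares over the coil dim
print(combined.shape)                          # torch.Size([320, 320])
```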
+ + Returns: + torch.Tensor: PyTorch version of data. + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + + return torch.from_numpy(data) + + +def tensor_to_complex_np(data): + """ + Converts a complex torch tensor to numpy array. + + Args: + data (torch.Tensor): Input data to be converted to numpy. + + Returns: + np.array: Complex numpy version of data. + """ + data = data.numpy() + + return data[..., 0] + 1j * data[..., 1] + + +def apply_mask(data, mask_func, seed=None, padding=None): + """ + Subsample given k-space by multiplying with a mask. + + Args: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + + Returns: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + +def mask_center(x, mask_from, mask_to): + mask = torch.zeros_like(x) + mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to] + + return mask + + +def center_crop(data, shape): + """ + Apply a center crop to the input real image or batch of real images. + + Args: + data (torch.Tensor): The input tensor to be center cropped. It should + have at least 2 dimensions and the cropping is applied along the + last two dimensions. + shape (int, int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image. + """ + assert 0 < shape[0] <= data.shape[-2] + assert 0 < shape[1] <= data.shape[-1] + + w_from = (data.shape[-2] - shape[0]) // 2 + h_from = (data.shape[-1] - shape[1]) // 2 + w_to = w_from + shape[0] + h_to = h_from + shape[1] + + return data[..., w_from:w_to, h_from:h_to] + + +def complex_center_crop(data, shape): + """ + Apply a center crop to the input image or batch of complex images. + + Args: + data (torch.Tensor): The complex input tensor to be center cropped. It + should have at least 3 dimensions and the cropping is applied along + dimensions -3 and -2 and the last dimensions should have a size of + 2. + shape (int): The output shape. The shape should be smaller than + the corresponding dimensions of data. + + Returns: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-3] + assert 0 < shape[1] <= data.shape[-2] + + w_from = (data.shape[-3] - shape[0]) // 2 #80 + h_from = (data.shape[-2] - shape[1]) // 2 #80 + w_to = w_from + shape[0] #240 + h_to = h_from + shape[1] #240 + + return data[..., w_from:w_to, h_from:h_to, :] + + +def center_crop_to_smallest(x, y): + """ + Apply a center crop on the larger image to the size of the smaller. + + The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at + dim=-1 and y is smaller than x at dim=-2, then the returned dimension will + be a mixture of the two. + + Args: + x (torch.Tensor): The first image. 
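A minimal sketch of how `apply_mask` is typically driven (mask settings are assumed, not taken from the file): build a mask function with `create_mask_for_mask_type` and multiply it into single-coil k-space of shape (rows, cols, 2).

```python
import torch

# Assumed 4X random mask with an 8% fully-sampled centre.
mask_func = create_mask_for_mask_type("random", [0.08], [4])

kspace = torch.randn(640, 368, 2)  # single-coil k-space, (rows, cols, 2)
masked_kspace, mask = apply_mask(kspace, mask_func, seed=42)

print(masked_kspace.shape)  # torch.Size([640, 368, 2])
print(mask.shape)           # torch.Size([1, 368, 1]) -- one flag per k-space column
```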
+ y (torch.Tensor): The second image + + Returns: + tuple: tuple of tensors x and y, each cropped to the minimim size. + """ + smallest_width = min(x.shape[-1], y.shape[-1]) + smallest_height = min(x.shape[-2], y.shape[-2]) + x = center_crop(x, (smallest_height, smallest_width)) + y = center_crop(y, (smallest_height, smallest_width)) + + return x, y + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + + Applies the formula (data - mean) / (stddev + eps). + + Args: + data (torch.Tensor): Input data to be normalized. + mean (float): Mean value. + stddev (float): Standard deviation. + eps (float, default=0.0): Added to stddev to prevent dividing by zero. + + Returns: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize the given tensor with instance norm/ + + Applies the formula (data - mean) / (stddev + eps), where mean and stddev + are computed from the data itself. + + Args: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + + Returns: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + + return normalize(data, mean, std, eps), mean, std + + +class DataTransform(object): + """ + Data Transformer for training U-Net models. + """ + + def __init__(self, which_challenge): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.which_challenge = which_challenge + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. 
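One detail worth making explicit: the transforms below normalize the input with `normalize_instance` and then reuse the same mean/std for the target, so input and target stay on a common scale. A small illustrative snippet (tensor sizes assumed):

```python
import torch

image = torch.rand(320, 320)
target = torch.rand(320, 320)

image_n, mean, std = normalize_instance(image, eps=1e-11)  # stats computed from the input
target_n = normalize(target, mean, std, eps=1e-11)         # same stats reused on the target

print(round(float(image_n.mean()), 3), round(float(image_n.std()), 3))  # ~0.0, ~1.0
```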
+ """ + kspace = to_tensor(kspace) + + # inverse Fourier transform to get zero filled solution + image = ifft2c(kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + + # getLR + imgfft = fft2c(image) + imgfft = complex_center_crop(imgfft, (160, 160)) + LR_image = ifft2c(imgfft) + + # absolute value + LR_image = complex_abs(LR_image) + + # normalize input + LR_image, mean, std = normalize_instance(LR_image, eps=1e-11) + LR_image = LR_image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return LR_image, target, mean, std, fname, slice_num + +class DenoiseDataTransform(object): + def __init__(self, size, noise_rate): + super(DenoiseDataTransform, self).__init__() + self.size = (size, size) + self.noise_rate = noise_rate + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + max_value = attrs["max"] + + #target + target = to_tensor(target) + target = center_crop(target, self.size) + target, mean, std = normalize_instance(target, eps=1e-11) + target = target.clamp(-6, 6) + + #image + kspace = to_tensor(kspace) + complex_image = ifft2c(kspace) #complex_image + image = complex_center_crop(complex_image, self.size) + noise_image = self.rician_noise(image, max_value) + noise_image = complex_abs(noise_image) + + noise_image = normalize(noise_image, mean, std, eps=1e-11) + noise_image = noise_image.clamp(-6, 6) + + return noise_image, target, mean, std, fname, slice_num + + + def rician_noise(self, X, noise_std): + #Add rician noise with variance sampled uniformly from the range 0 and 0.1 + noise_std = random.uniform(0, noise_std*self.noise_rate) + Ir = X + noise_std * torch.randn(X.shape) + Ii = noise_std*torch.randn(X.shape) + In = torch.sqrt(Ir ** 2 + Ii ** 2) + return In + + +def apply_mask( + data: torch.Tensor, + mask_func: MaskFunc, + seed: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Optional[Sequence[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Subsample given k-space by multiplying with a mask. + Args: + data: The input k-space data. This should have at least 3 dimensions, + where dimensions -3 and -2 are the spatial dimensions, and the + final dimension has size 2 (for complex values). + mask_func: A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed: Seed for the random number generator. + padding: Padding value to apply for mask. + Returns: + tuple containing: + masked data: Subsampled k-space data + mask: The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + mask = mask_func(shape, seed) + if padding is not None: + mask[:, :, : padding[0]] = 0 + mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros + + masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros + + return masked_data, mask + + + +class ReconstructionTransform(object): + """ + Data Transformer for training U-Net models. 
+ """ + + def __init__(self, which_challenge, mask_func=None, use_seed=True): + """ + Args: + which_challenge (str): Either "singlecoil" or "multicoil" denoting + the dataset. + mask_func (fastmri.data.subsample.MaskFunc): A function that can + create a mask of appropriate shape. + use_seed (bool): If true, this class computes a pseudo random + number generator seed from the filename. This ensures that the + same mask is used for all the slices of a given volume every + time. + """ + if which_challenge not in ("singlecoil", "multicoil"): + raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') + + self.mask_func = mask_func + self.which_challenge = which_challenge + self.use_seed = use_seed + + def __call__(self, kspace, mask, target, attrs, fname, slice_num): + """ + Args: + kspace (numpy.array): Input k-space of shape (num_coils, rows, + cols, 2) for multi-coil data or (rows, cols, 2) for single coil + data. + mask (numpy.array): Mask from the test dataset. + target (numpy.array): Target image. + attrs (dict): Acquisition related information stored in the HDF5 + object. + fname (str): File name. + slice_num (int): Serial number of the slice. + + Returns: + (tuple): tuple containing: + image (torch.Tensor): Zero-filled input image. + target (torch.Tensor): Target image converted to a torch + Tensor. + mean (float): Mean value used for normalization. + std (float): Standard deviation value used for normalization. + fname (str): File name. + slice_num (int): Serial number of the slice. + """ + kspace = to_tensor(kspace) + + # apply mask + if self.mask_func: + seed = None if not self.use_seed else tuple(map(ord, fname)) + masked_kspace, mask = apply_mask(kspace, self.mask_func, seed) + else: + masked_kspace = kspace + + # inverse Fourier transform to get zero filled solution + image = ifft2c(masked_kspace) + + # crop input to correct size + if target is not None: + crop_size = (target.shape[-2], target.shape[-1]) + else: + crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) + + # check for sFLAIR 203 + if image.shape[-2] < crop_size[1]: + crop_size = (image.shape[-2], image.shape[-2]) + + image = complex_center_crop(image, crop_size) + # print('image',image.shape) + # absolute value + image = complex_abs(image) + + # apply Root-Sum-of-Squares if multicoil data + if self.which_challenge == "multicoil": + image = rss(image) + + # normalize input + image, mean, std = normalize_instance(image, eps=1e-11) + image = image.clamp(-6, 6) + + # normalize target + if target is not None: + target = to_tensor(target) + target = center_crop(target, crop_size) + target = normalize(target, mean, std, eps=1e-11) + target = target.clamp(-6, 6) + else: + target = torch.Tensor([0]) + + return image, target, mean, std, fname, slice_num + + +def build_transforms(args, mode = 'train'): + + challenge = 'singlecoil' + if mode == 'train': + mask = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + return ReconstructionTransform(challenge, mask, use_seed=False) + elif mode == 'val': + mask = create_mask_for_mask_type( + args.MASKTYPE, args.CENTER_FRACTIONS, args.ACCELERATIONS, + ) + return ReconstructionTransform(challenge, mask) + else: + return ReconstructionTransform(challenge) + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/documents/INSTALL.md b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/documents/INSTALL.md new file mode 100644 index 
0000000000000000000000000000000000000000..9912721cb3354240d99c08838ae8d2b1417b339b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/documents/INSTALL.md @@ -0,0 +1,11 @@ +## Dependency +The code is tested on `python 3.8, Pytorch 1.13`. + +##### Setup environment + +```bash +conda create -n FSMNet python=3.8 +source activate FSMNet # or conda activate FSMNet +pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install einops h5py matplotlib scikit_image tensorboardX yacs pandas opencv-python timm ml_collections +``` diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/figures/FSMNet.png b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/figures/FSMNet.png new file mode 100644 index 0000000000000000000000000000000000000000..127848f2c580c8d91d9cff8890500e5f3c830d72 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/figures/FSMNet.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40bb9cbda0a8f926ea4ef8d92228ce591766b1d3176000db5758b2edf1a6249b +size 378629 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/knee_data_split/singlecoil_train_split_less.csv b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/knee_data_split/singlecoil_train_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..d85707318750900b14a6e7100541242a60b7a310 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/knee_data_split/singlecoil_train_split_less.csv @@ -0,0 +1,227 @@ +file1000685,file1000568,0.301723929779229 +file1002273,file1000481,0.302226224199571 +file1000472,file1000142,0.304272730770318 +file1002186,file1000863,0.304812175768496 +file1002385,file1002518,0.305357274240413 +file1000981,file1000129,0.305533361411383 +file1001320,file1001948,0.306821514316368 +file1000633,file1002243,0.306892354331709 +file1001872,file1001294,0.308345907393103 +file1001474,file1001830,0.310481695157561 +file1001005,file1001283,0.310497722435023 +file1001690,file1001519,0.310709448786299 +file1002469,file1001811,0.31193137253455 +file1000914,file1000242,0.31237190359308 +file1002284,file1002012,0.315366393843169 +file1001721,file1001328,0.31735122361847 +file1000807,file1002334,0.320096908959039 +file1001944,file1002335,0.320272061156991 +file1002090,file1002431,0.320351887633851 +file1000499,file1002063,0.320786426659383 +file1001362,file1000509,0.32175341740359 +file1001421,file1000597,0.324291432700032 +file1000349,file1000321,0.324545110048573 +file1002123,file1001235,0.327142348994532 +file1001867,file1002086,0.328624781732941 +file1001007,file1001027,0.330759860300298 +file1001915,file1000088,0.331499371283099 +file1001661,file1000313,0.331905252950291 +file1000383,file1000307,0.339998107225229 +file1000116,file1000632,0.34069458535013 +file1002303,file1000173,0.343821267871409 +file1000306,file1001277,0.344751178043605 +file1000003,file1001922,0.346138116633394 +file1000109,file1000143,0.347632265547478 +file1001999,file1000115,0.348248659775587 +file1000089,file1000326,0.348964657514049 +file1001205,file1002232,0.349375610862454 +file1000557,file1000619,0.351305005151048 +file1001823,file1000778,0.352076809462453 +file1000806,file1001130,0.352659078122633 +file1000365,file1000351,0.352772816610486 +file1002374,file1001778,0.352974481603711 +file1002516,file1001910,0.359896103026675 
+file1001200,file1000931,0.360070003966827 +file1001479,file1000952,0.360424533696936 +file1000850,file1001942,0.362632797518558 +file1001426,file1002143,0.363271909822866 +file1001304,file1001333,0.36404737582222 +file1000390,file1000518,0.364744579516818 +file1000830,file1002096,0.365897427529429 +file1000794,file1001856,0.365973692948894 +file1001266,file1001327,0.366395851089761 +file1001692,file1002352,0.36655953875445 +file1001564,file1001024,0.367284385415205 +file1001861,file1002050,0.36783497787384 +file1002066,file1002361,0.367964419694875 +file1001613,file1002087,0.368231014746024 +file1001931,file1000220,0.368847112914793 +file1000339,file1000554,0.370123905662701 +file1000754,file1002208,0.37031588493778 +file1001067,file1001956,0.371313060558732 +file1000101,file1001053,0.372141932838775 +file1002520,file1002409,0.372501194473693 +file1001459,file1001615,0.373295536945146 +file1001673,file1000508,0.376416667681519 +file1002201,file1001228,0.376680033570078 +file1000058,file1002449,0.376927627737029 +file1001748,file1001042,0.378067114701689 +file1001941,file1000376,0.37841176147662 +file1000801,file1002545,0.378423759459738 +file1000010,file1000535,0.38111194591455 +file1000882,file1002154,0.382223600234592 +file1001694,file1001297,0.382545161354354 +file1001992,file1002456,0.382664563820782 +file1001666,file1001773,0.382892588770697 +file1001629,file1002514,0.383417073960824 +file1002113,file1000738,0.385439884728523 +file1002221,file1000569,0.385903801966773 +file1002296,file1002117,0.387319754665673 +file1000693,file1001945,0.387855926202209 +file1001410,file1000223,0.391284037867147 +file1002071,file1001425,0.391497653794399 +file1002325,file1001259,0.391913965917762 +file1002430,file1001969,0.392256443856501 +file1002462,file1000708,0.393161981208355 +file1002358,file1001888,0.39427809496515 +file1000485,file1000753,0.395316199436001 +file1002357,file1001973,0.39564210237905 +file1002130,file1002041,0.395978941103639 +file1002569,file1000097,0.397496127623486 +file1002264,file1000148,0.397630184088734 +file1002381,file1001401,0.398105992102355 +file1000289,file1000585,0.399527637723015 +file1002368,file1001723,0.400243022234875 +file1002342,file1001319,0.400431803928825 +file1002170,file1001226,0.400632448147846 +file1001385,file1001758,0.400855988878681 +file1001732,file1002541,0.40091828863264 +file1001102,file1000762,0.400923140595936 +file1001470,file1000181,0.401353492516182 +file1000400,file1000884,0.401562860630016 +file1002293,file1002523,0.401800994807451 +file1000728,file1001654,0.402763341041675 +file1000582,file1001491,0.403451830806034 +file1000586,file1001521,0.403648293267187 +file1002287,file1001770,0.405194821414496 +file1000371,file1000159,0.405999000381268 +file1002356,file1002064,0.406519210876811 +file1000324,file1000590,0.407593694425997 +file1001622,file1001710,0.40759525378577 +file1002037,file1000403,0.407814136488744 +file1002444,file1000743,0.40943197761463 +file1001175,file1002088,0.410423663035312 +file1001391,file1000540,0.410854355646853 +file1002133,file1001186,0.411248429534111 +file1001229,file1001630,0.411355571792039 +file1002283,file1000402,0.411836769927671 +file1000627,file1000161,0.412089060388579 +file1001701,file1001402,0.412854774524637 +file1000795,file1000452,0.413448916432685 +file1000354,file1000947,0.41459642292987 +file1002043,file1002505,0.414863932355455 +file1001285,file1001113,0.418183757940871 +file1000170,file1001832,0.419441549204313 +file1002399,file1001500,0.419905873946513 
+file1002439,file1000177,0.42054051043224 +file1001656,file1001217,0.420597020703942 +file1000296,file1000065,0.420845042251081 +file1000626,file1001623,0.42087934790355 +file1001767,file1000760,0.422315537515139 +file1000467,file1001246,0.422371268999111 +file1001033,file1000611,0.42425275873442 +file1002304,file1000221,0.425602179771197 +file1001737,file1001141,0.425716789218234 +file1001565,file1000559,0.426158561043574 +file1000249,file1000643,0.426541100077021 +file1002014,file1001109,0.426587840438723 +file1002006,file1000790,0.427829459781438 +file1000193,file1000750,0.428103808477214 +file1001993,file1001110,0.428186367615143 +file1002094,file1001814,0.428868578868176 +file1000098,file1001420,0.428968675677784 +file1000336,file1000211,0.430347427208789 +file1001498,file1002568,0.43204475404071 +file1001671,file1001106,0.432215802861284 +file1000426,file1002386,0.43283446816702 +file1001520,file1002481,0.434867670495723 +file1002189,file1001432,0.434924370194975 +file1001390,file1002554,0.435313848731387 +file1002166,file1001982,0.435387512979012 +file1001120,file1001006,0.435594761785839 +file1000149,file1001985,0.436289528591294 +file1001632,file1001008,0.436682374331417 +file1002567,file1001155,0.437221000601772 +file1000434,file1002195,0.438098100114814 +file1002532,file1001048,0.438500899539101 +file1001605,file1000927,0.438686659342641 +file1000479,file1000120,0.439587267995034 +file1002473,file1001388,0.439594997597548 +file1001108,file1002228,0.440528754793898 +file1002099,file1002056,0.440776843467602 +file1000191,file1002127,0.441114509542672 +file1000875,file1002494,0.441378135507993 +file1002161,file1000002,0.441912476744187 +file1002269,file1001220,0.442742296865228 +file1001295,file1001355,0.4435162405589 +file1001659,file1001023,0.444686151316673 +file1001857,file1001378,0.447500830900898 +file1001183,file1001370,0.447782748040587 +file1000428,file1000859,0.448328910257083 +file1000588,file1002227,0.448650488897259 +file1001098,file1000486,0.448862467740607 +file1001288,file1000408,0.450363676957042 +file1002097,file1001210,0.451126832474666 +file1000216,file1001082,0.451550143520946 +file1001746,file1001642,0.451781042569196 +file1002388,file1000204,0.451940333555972 +file1000021,file1000560,0.452234621797968 +file1000489,file1001545,0.452796032302523 +file1001116,file1000883,0.453096911915119 +file1001372,file1000561,0.45532542913335 +file1001276,file1000424,0.45534174289324 +file1000974,file1002098,0.455371894001872 +file1002566,file1002044,0.455937677517583 +file1000262,file1002046,0.456056330767294 +file1001619,file1001342,0.456559091350965 +file1000045,file1001616,0.457599407743834 +file1001468,file1002115,0.458095965024278 +file1001061,file1000233,0.460561351667266 +file1000558,file1000100,0.461094222462111 +file1000605,file1000691,0.461429521647285 +file1000640,file1000384,0.463383466503099 +file1000410,file1001358,0.463452482427773 +file1000851,file1001014,0.463558384057952 +file1001092,file1000138,0.463591264436099 +file1000061,file1002049,0.465778207162619 +file1001206,file1000983,0.466701211830884 +file1000256,file1000475,0.466865377968187 +file1002434,file1001387,0.467154181996099 +file1001036,file1000210,0.470404279499276 +file1001540,file1001860,0.472822271037545 +file1001244,file1001154,0.475076170733515 +file1000131,file1001526,0.475459563440874 +file1000180,file1002045,0.476814451110009 +file1001837,file1000637,0.478851985878026 +file1002425,file1001891,0.481451070031007 +file1001056,file1000682,0.482320170742015 
+file1002276,file1000777,0.483452141843029 +file1001139,file1002544,0.487462418948035 +file1000548,file1001257,0.488098081542811 +file1000188,file1001286,0.488423105111001 +file1001879,file1000999,0.488449105381724 +file1001062,file1000231,0.48930683373911 +file1000040,file1001873,0.492070802214623 +file1002286,file1000066,0.493213986773381 +file1002474,file1002563,0.501584439120211 +file1000967,file1000563,0.502066261411662 +file1001307,file1002048,0.50460435259807 +file1000483,file1001699,0.511819026566198 +file1001528,file1000285,0.512629017841038 +file1001742,file1002371,0.513805213204644 +file1002397,file1000592,0.515406473057 +file1000069,file1000510,0.528220553613126 +file1001087,file1001300,0.536510449049583 +file1001991,file1000836,0.538145797125916 +file1001382,file1001806,0.538539506621535 +file1000111,file1001189,0.557690760784602 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/knee_data_split/singlecoil_val_split_less.csv b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/knee_data_split/singlecoil_val_split_less.csv new file mode 100644 index 0000000000000000000000000000000000000000..a1cbac5537562063359f4ac3e0985de51cb989b2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/knee_data_split/singlecoil_val_split_less.csv @@ -0,0 +1,45 @@ +file1000323,file1002538,0.30754967523156 +file1001458,file1001566,0.310512744537048 +file1000885,file1001059,0.318226346221521 +file1000464,file1000196,0.321465466968232 +file1000314,file1000178,0.327505552363568 +file1001163,file1001289,0.328954963947692 +file1000033,file1001191,0.330925609207301 +file1000976,file1000990,0.344036229323198 +file1001930,file1001834,0.345994076497818 +file1002546,file1001344,0.351762252794677 +file1000277,file1001429,0.353297786572139 +file1001893,file1001262,0.358064285890878 +file1000926,file1002067,0.360639004205491 +file1001650,file1002002,0.362186928073579 +file1001184,file1001655,0.362592305723707 +file1001497,file1001338,0.365599407221502 +file1001202,file1001365,0.3844323497275 +file1001126,file1002340,0.388929627976346 +file1001339,file1000291,0.391300537691403 +file1002187,file1001862,0.39883786878841 +file1000041,file1000591,0.39896683485823 +file1001064,file1001850,0.399687813966601 +file1001331,file1002214,0.400340820924839 +file1000831,file1000528,0.403582747590964 +file1000769,file1000538,0.405298051020298 +file1000182,file1001968,0.407646172205036 +file1002382,file1001651,0.410749052045234 +file1000660,file1000476,0.415423894745454 +file1002570,file1001726,0.424622351472032 +file1001585,file1000858,0.426738511964108 +file1000190,file1000593,0.428080574167047 +file1001170,file1001090,0.429987089825525 +file1002252,file1001440,0.432038842370013 +file1000697,file1001144,0.432558506761396 +file1001077,file1000000,0.441922503777368 +file1001381,file1001119,0.455418270809002 +file1001759,file1001851,0.460824505737749 +file1000635,file1002389,0.465674267492171 +file1001668,file1001689,0.467330511330772 +file1001221,file1000818,0.469630000354232 +file1001298,file1002145,0.473526387887779 +file1001763,file1001938,0.47398893150184 +file1001444,file1000942,0.48507438696692 +file1000735,file1002007,0.496530240691134 +file1000477,file1000280,0.528508000547834 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/metric.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..53ddb27a96bab67975beef06ca6819e628208153 --- 
/dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/metric.py @@ -0,0 +1,51 @@ + +import numpy as np +from skimage.metrics import peak_signal_noise_ratio, structural_similarity + +def nmse(gt, pred): + """Compute Normalized Mean Squared Error (NMSE)""" + return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 + + +def psnr(gt, pred): + """Compute Peak Signal to Noise Ratio metric (PSNR)""" + return peak_signal_noise_ratio(gt, pred, data_range=gt.max()) + + +def ssim(gt, pred, maxval=None): + """Compute Structural Similarity Index Metric (SSIM)""" + maxval = gt.max() if maxval is None else maxval + + ssim = 0 + for slice_num in range(gt.shape[0]): + ssim = ssim + structural_similarity( + gt[slice_num], pred[slice_num], data_range=maxval + ) + + ssim = ssim / gt.shape[0] + + return ssim + + +class AverageMeter(object): + """Computes and stores the average and current value. + + Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 + """ + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.score = [] + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + self.score.append(val) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/common_freq.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..79cf3e778029a846b4da910c115c8315bf33dbaf --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/common_freq.py @@ -0,0 +1,389 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
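A minimal round-trip sketch (illustrative only, using nothing beyond standard PyTorch): the invPixelShuffle defined above appears to coincide with torch.nn.PixelUnshuffle, so shuffling down and back up with nn.PixelShuffle is a pure, lossless rearrangement of pixels into channels:

import torch
import torch.nn as nn

x = torch.randn(2, 16, 8, 8)
down = nn.PixelUnshuffle(2)(x)   # (2, 64, 4, 4), same layout as invPixelShuffle(ratio=2)
up = nn.PixelShuffle(2)(down)    # back to (2, 16, 8, 8)
assert torch.equal(up, x)        # values are only rearranged, never changed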
+ for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs): + + out = self.layers(inputs) + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features): + super(ResBlock, self).__init__() + self.layers = nn.Sequential( + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1), + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + ) + + def forward(self, inputs): + return F.relu(self.layers(inputs) + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + 
invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [ + ResBlock(n_feat) for _ in range(n_resblocks)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x): + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels, args): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + self.spa_att = 
Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, 
nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ARTUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae23fed16b0d9454a5ea09f17b4322f1e9d7e928 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ARTUnet.py @@ -0,0 +1,751 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
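A small shape sketch (illustrative only, with toy sizes assumed here: an 8x8 map, window_size G=4, interval I=2) of the two groupings built above; both the dense windows and the sparse strided groups are flattened to (num_groups*B, tokens, C) before being fed to the shared Attention module:

import torch

B, C, Hd, Wd, G, I = 1, 3, 8, 8, 4, 2
x = torch.randn(B, Hd, Wd, C)

# Dense attention: non-overlapping G x G windows
dense = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).reshape(-1, G * G, C)

# Sparse attention: pixels taken every I-th position along each axis form one group
Gh, Gw = Hd // I, Wd // I
sparse = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).reshape(-1, Gh * Gw, C)

print(dense.shape, sparse.shape)  # torch.Size([4, 16, 3]) torch.Size([4, 16, 3])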
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + # print("x shape:", x.shape) + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +########################################################################## +class CS_TransformerBlock(nn.Module): + r""" 在ART Transformer Block的基础上添加了channel attention的分支。 + 参考论文: Dual Aggregation Transformer for Image Super-Resolution + 及其代码: https://github.com/zhengchen1999/DAT/blob/main/basicsr/archs/dat_arch.py + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.InstanceNorm2d(dim), + nn.GELU()) + + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + # nn.InstanceNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1)) + + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.InstanceNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # print("x before dwconv:", x.shape) + # convolution output + conv_x = self.dwconv(x.permute(0, 3, 1, 2)) + # conv_x = x.permute(0, 3, 1, 2) + # print("x after dwconv:", x.shape) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + attened_x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + attened_x = attened_x[:, :H, :W, :].contiguous() + + attened_x = attened_x.view(B, H * W, C) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C) + # S-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W) + spatial_map = self.spatial_interaction(attention_reshape) + + # C-I + attened_x = attened_x * torch.sigmoid(channel_map) + # S-I + conv_x = torch.sigmoid(spatial_map) * conv_x + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + + x = attened_x + conv_x + + # 
FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ARTUnet_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ARTUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d210d46e2a4d73781b3b16b63a53fdc0f25b9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ARTUnet_new.py @@ -0,0 +1,823 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +Unet的格式构建image restoration 网络。每个transformer block中都是dense self-attention和sparse self-attention交替连接。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True, visualize_attention_maps=False): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + self.visualize_attention_maps = visualize_attention_maps + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + # assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + # qkv: 3, B_, num_heads, N, C // num_heads + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + # B_, num_heads, N, C // num_heads * + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + attn_map = attn.detach().cpu().numpy().mean(axis=1) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) # (B * nP, nHead, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + 
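A tiny sketch (illustrative only, toy sizes assumed) of the head split performed above and the merge applied just below: reshaping to (B, heads, N, C/heads) for the per-head products and transposing/reshaping back is an exact inverse when no token mixing happens in between:

import torch

B_, N, C, heads = 2, 16, 32, 8
x = torch.randn(B_, N, C)
h = x.reshape(B_, N, heads, C // heads).permute(0, 2, 1, 3)  # (B_, heads, N, C // heads)
merged = h.transpose(1, 2).reshape(B_, N, C)                 # back to (B_, N, C)
assert torch.equal(merged, x)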
# attn: B * nP, num_heads, N, N + # v: B * nP, num_heads, N, C // num_heads + # -attn-@-v-> B * nP, num_heads, N, C // num_heads + # -transpose-> B * nP, N, num_heads, C // num_heads + # -reshape-> B * nP, N, C + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, G ** 2, G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, Gh * Gw, Gh * Gw) + else: + x = self.attn(x, Gh, Gw, mask=attn_mask) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +class ConcatTransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. 
+ num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + ################## + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // 
self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + ################## + + x = torch.cat((x, y), dim=1) + + shortcut = x + x = self.norm1(x) + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, 2 * Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, 2 * G ** 2, 2 * G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, 2 * Gh * Gw, 2 * Gh * Gw) + else: + x = self.attn(x, 2 * Gh, Gw, mask=attn_mask) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + output = x + + # print(x.shape, Gh, Gw) + x = output[:, :Gh * Gw, :] + y = output[:, Gh * Gw:, :] + assert x.shape == y.shape + + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = y.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = y.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, y, attn_map + else: + return x, y + + def get_patches(self, x, x_size): + H, W = x_size + B, H, W, C = x.shape + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, 
Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + return x, attn_mask, (Hd, Wd), (Gh, Gw), (pad_r, pad_b), G, I + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ART_Restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ART_Restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c23123f0ac1a8e82e2bf907658ce8d91951c3e4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ART_Restormer.py @@ -0,0 +1,592 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure; each block is built by cascading an ART block and a Restormer block in series. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[4, 4, 4, 4], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + + self.down1_2 = Downsample(dim) ## 
From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + 
self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + + ######################################### Encoder level1 ############################################ + ### ART encoder block + for layer in self.ART_encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + + ### Restormer encoder block + out_enc_level1 = rearrange(out_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = self.Restormer_encoder_level1(out_enc_level1) + # stack.append(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + ######################################### Encoder level2 ############################################ + ### ART encoder block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.ART_encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + out_enc_level2 = rearrange(out_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = self.Restormer_encoder_level2(out_enc_level2) + # stack.append(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.ART_encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + out_enc_level3 = rearrange(out_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = self.Restormer_encoder_level3(out_enc_level3) + # stack.append(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 
############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.ART_latent: + latent = layer(latent, [H // 8, W // 8]) + + ### Restormer encoder block + latent = rearrange(latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = self.Restormer_latent(latent) + # stack.append(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + out_dec_level3 = inp_dec_level3 + for layer in self.ART_decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + out_dec_level3 = rearrange(out_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + out_dec_level3 = self.Restormer_decoder_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + out_dec_level2 = inp_dec_level2 + for layer in self.ART_decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + out_dec_level2 = rearrange(out_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + out_dec_level2 = self.Restormer_decoder_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + out_dec_level1 = inp_dec_level1 + for layer in self.ART_decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + ### Restormer decoder block + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_decoder_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_refinement(out_dec_level1) + + out_dec_level1 = 
self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ART_Restormer_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ART_Restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..de8f921e81d79cb7af14a75326f0ddaaf3b3f4c3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ART_Restormer_v2.py @@ -0,0 +1,663 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure; each block is built from an ART branch and a Restormer branch in parallel. +The input features are passed through the ART and Restormer branches separately; the two feature maps are concatenated and fused by a conv layer to produce the block's output features. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.enc_fuse_level1 = nn.Conv2d(int(dim * 
2 ** 1), dim, kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.enc_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.enc_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + self.enc_fuse_level4 = nn.Conv2d(int(dim * 2 ** 4), int(dim * 2 ** 3), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.dec_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], 
window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.dec_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.dec_fuse_level1 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + # self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + # bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + # self.fuse_refinement = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + # def forward(self, inp_img, aux_img): + # stack = [] + # bs, _, H, W = inp_img.shape + + # fuse_img = torch.cat((inp_img, aux_img), 1) + # inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + + def forward(self, inp_img): + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + + ######################################### Encoder level1 ############################################ + art_enc_level1 = inp_enc_level1 + restor_enc_level1 = inp_enc_level1 + + ### ART encoder block + for layer in self.ART_encoder_level1: + art_enc_level1 = layer(art_enc_level1, [H, W]) + + ### Restormer encoder block + restor_enc_level1 = rearrange(restor_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_enc_level1 = self.Restormer_encoder_level1(restor_enc_level1) ### (b, c, h, w) + + art_enc_level1 = rearrange(art_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = torch.cat((art_enc_level1, restor_enc_level1), 1) + out_enc_level1 = self.enc_fuse_level1(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + + ######################################### Encoder level2 ############################################ + ### ART encoder 
block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + art_enc_level2 = inp_enc_level2 + restor_enc_level2 = inp_enc_level2 + + for layer in self.ART_encoder_level2: + art_enc_level2 = layer(art_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + restor_enc_level2 = rearrange(restor_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + restor_enc_level2 = self.Restormer_encoder_level2(restor_enc_level2) + + art_enc_level2 = rearrange(art_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = torch.cat((art_enc_level2, restor_enc_level2), 1) + out_enc_level2 = self.enc_fuse_level2(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + art_enc_level3 = inp_enc_level3 + restor_enc_level3 = inp_enc_level3 + + for layer in self.ART_encoder_level3: + art_enc_level3 = layer(art_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + restor_enc_level3 = rearrange(restor_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + restor_enc_level3 = self.Restormer_encoder_level3(restor_enc_level3) + + art_enc_level3 = rearrange(art_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = torch.cat((art_enc_level3, restor_enc_level3), 1) + out_enc_level3 = self.enc_fuse_level3(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 ############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + art_latent = inp_enc_level4 + restor_latent = inp_enc_level4 + + for layer in self.ART_latent: + art_latent = layer(art_latent, [H // 8, W // 8]) + + ### Restormer encoder block + restor_latent = rearrange(restor_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + restor_latent = self.Restormer_latent(restor_latent) + + art_latent = rearrange(art_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = torch.cat((art_latent, restor_latent), 1) + latent = self.enc_fuse_level4(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + art_dec_level3 = inp_dec_level3 + restor_dec_level3 = inp_dec_level3 + + for layer in self.ART_decoder_level3: + art_dec_level3 = layer(art_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + restor_dec_level3 = rearrange(restor_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + restor_dec_level3 = 
self.Restormer_decoder_level3(restor_dec_level3) + + art_dec_level3 = rearrange(art_dec_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_dec_level3 = torch.cat((art_dec_level3, restor_dec_level3), 1) + out_dec_level3 = self.dec_fuse_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + art_dec_level2 = inp_dec_level2 + restor_dec_level2 = inp_dec_level2 + + for layer in self.ART_decoder_level2: + art_dec_level2 = layer(art_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + restor_dec_level2 = rearrange(restor_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + restor_dec_level2 = self.Restormer_decoder_level2(restor_dec_level2) + + art_dec_level2 = rearrange(art_dec_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_dec_level2 = torch.cat((art_dec_level2, restor_dec_level2), 1) + out_dec_level2 = self.dec_fuse_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + art_dec_level1 = inp_dec_level1 + restor_dec_level1 = inp_dec_level1 + + for layer in self.ART_decoder_level1: + art_dec_level1 = layer(art_dec_level1, [H, W]) + + ### Restormer decoder block + restor_dec_level1 = rearrange(restor_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_dec_level1 = self.Restormer_decoder_level1(restor_dec_level1) + + art_dec_level1 = rearrange(art_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = torch.cat((art_dec_level1, restor_dec_level1), 1) + out_dec_level1 = self.dec_fuse_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +# def build_model(args): +# return ART_Restormer( +# inp_channels=2, +# out_channels=1, +# dim=32, +# num_art_blocks=[2, 4, 4, 6], +# num_restormer_blocks=[2, 2, 2, 2], +# num_refinement_blocks=4, +# heads=[1, 2, 4, 8], +# window_size=[10, 10, 10, 10], mlp_ratio=4., +# qkv_bias=True, qk_scale=None, +# interval=[24, 12, 6, 3], +# ffn_expansion_factor = 2.66, +# LayerNorm_type = 'WithBias', ## Other option 'BiasFree' +# bias=False) + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 
2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ARTfuse_layer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ARTfuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..78af0b27528b50db897936f6b8b4eee51d566b1c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/ARTfuse_layer.py @@ -0,0 +1,705 @@ +""" +ART layer in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input to the Attention layer:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + # print("q range:", q.max(), q.min()) + # print("k range:", k.max(), k.min()) + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + # print("attention matrix:", attn.shape, attn) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + # print("feature before proj-layer:", x.max(), x.min()) + x = self.proj(x) + # print("feature after proj-layer:", x.max(), x.min()) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. 
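+
+    The block applies multi-head self-attention with a dynamic position bias over
+    token groups that are either dense (ds_flag=0: non-overlapping window_size x
+    window_size windows) or sparse (ds_flag=1: tokens sampled with stride `interval`
+    across the whole map), followed by a residual MLP. Pre-norm or post-norm
+    ordering is selected with `pre_norm`. When min(H, W) <= window_size the block
+    falls back to dense attention with window_size = min(H, W).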
+ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pre_norm=True): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + # self.conv = nn.Conv2d(dim, dim, kernel_size=1) + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + # self.conv = ConvBlock(self.dim, self.dim, drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.bn_norm = nn.BatchNorm2d(dim) + self.pre_norm = pre_norm + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + # x = self.conv(x) + shortcut = x + if self.pre_norm: + # print("using pre_norm") + x = self.norm1(x) ## 归一化之后特征范围变得非常小 + x = x.view(B, H, W, C) + # print("normalized input range:", x.max(), x.min(), x.mean()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 
1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + # print("range of feature before layer:", x.max(), x.min(), x.mean()) + # x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + # x_bn = self.bn_norm(x.permute(0, 3, 1, 2)) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + # print("[ART layer] input and output feature difference:", torch.mean(torch.abs(shortcut - x))) + # print("range of feature before ART layer:", shortcut.max(), shortcut.min(), shortcut.mean()) + # print("range of feature after ART layer:", x.max(), x.min(), x.mean()) + if self.pre_norm: + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + # print("using post_norm") + x = self.norm1(shortcut + self.drop_path(x)) + x = self.norm2(x + self.drop_path(self.mlp(x))) + # x = shortcut + self.drop_path(x) + # x = x + self.drop_path(self.mlp(x)) + + # print("range of fused feature:", x.max(), x.min()) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + + + +########################################################################## +class Cross_TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +class Cross_TransformerBlock_v2(nn.Module): + r""" ART Transformer Block. + 将两个模态的特征沿着channel方向concat, 一起取window. 之后把取出来的两个模态中同一个window的所有patches合并,送到transformer处理。 + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
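+
+        In this v2 variant the two modalities are first concatenated along the
+        channel dimension, so the window / interval partition is taken jointly;
+        inside each group the 2C channels are then split back into two C-channel
+        token sets that are merged along the token dimension for one shared
+        self-attention pass, and finally separated again.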
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + xy = torch.cat((x, y), -1) + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + xy = F.pad(xy, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + xy = xy.reshape(B, Hd // G, G, Wd // G, G, 2*C).permute(0, 1, 3, 2, 4, 5).contiguous() + xy = xy.reshape(B * Hd * Wd // G ** 2, G ** 2, 2*C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + xy = xy.reshape(B, Gh, I, Gw, I, 2*C).permute(0, 2, 4, 1, 3, 5).contiguous() + xy = xy.reshape(B * I * I, Gh * Gw, 2*C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ x = xy[:, :, :C] + y = xy[:, :, C:] + xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DataConsistency.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DataConsistency.py new file mode 100644 index 0000000000000000000000000000000000000000..5f460e9acabcb33e1a3f812eb9acaeb337da446c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DataConsistency.py @@ -0,0 +1,48 @@ +""" +Created: DataConsistency @ Xiyang Cai, 2023/09/09 + +Data consistency layer for k-space signal. 
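+
+The DC operation keeps the originally sampled k-space values and replaces only the
+unsampled locations with the network prediction:
+
+    out = (1 - mask) * k + mask * k0
+
+i.e. wherever mask == 1 the measured data k0 is restored.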
+ +Ref: DataConsistency in DuDoRNet (https://github.com/bbbbbbzhou/DuDoRNet) + +""" + +import torch +from torch import nn +from einops import repeat + + +def data_consistency(k, k0, mask): + """ + k - input in k-space + k0 - initially sampled elements in k-space + mask - corresponding nonzero location + """ + out = (1 - mask) * k + mask * k0 + return out + + +class DataConsistency(nn.Module): + """ + Create data consistency operator + """ + + def __init__(self): + super(DataConsistency, self).__init__() + + def forward(self, k, k0, mask): + """ + k - input in frequency domain, of shape (n, nx, ny, 2) + k0 - initially sampled elements in k-space + mask - corresponding nonzero location (n, 1, len, 1) + """ + + if k.dim() != 4: # input is 2D + raise ValueError("error in data consistency layer!") + + # mask = repeat(mask.squeeze(1, 3), 'b x -> b x y c', y=k.shape[1], c=2) + mask = torch.tile(mask, (1, mask.shape[2], 1, k.shape[-1])) ### [n, 320, 320, 2] + # print("k and k0 shape:", k.shape, k0.shape) + out = data_consistency(k, k0, mask) + + return out, mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DuDo_mUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DuDo_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a56069a27f0cdd392482b41dc4e3b79252295487 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DuDo_mUnet.py @@ -0,0 +1,359 @@ +""" +Xiaohan Xing, 2023/10/25 +传入两个模态的kspace data, 经过kspace network进行重建。 +然后把重建之后的kspace data变换回图像,然后送到image domain的网络进行图像重建。 +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + +from .Kspace_mUnet import kspace_mmUnet + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = kspace_ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = kspace_ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = kspace_ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = kspace_ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
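+
+        The target and reference k-space are concatenated and fed to three parallel
+        conv branches (3x3, 5x5 and 7x7 kernels); a 1x1-conv / sigmoid head produces
+        a spatial weight map that rescales the merged features (element-wise
+        multiplication in the frequency domain), and the result passes through a
+        final conv block and the data-consistency layer. Both the DC-corrected
+        k-space and its masked copy are returned.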
+ """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. + spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class DuDo_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 3, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + # self.kspace_net = mmConvKSpace() + self.kspace_net = kspace_mmUnet(args) + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + stack = [] + feature_stack = [] + # output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ output = torch.cat((image, LR_image, aux_image), 1) #.detach() + + # print("LR image range:", LR_image.max(), LR_image.min()) + # print("kspace_recon image range:", image.max(), image.min()) + # print("aux_image range:", aux_image.max(), aux_image.min()) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output, image, recon_kspace, masked_kspace, data_stats + + + +class kspace_ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return DuDo_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6df67b63140ca51bd7b8664151ab4b1b7a6305d5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DuDo_mUnet_ARTfusion.py @@ -0,0 +1,369 @@ +""" +Xiaohan Xing, 2023/11/08 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + ARTfusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = 
transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
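+
+        Returns the reconstructed T1 and T2 images, the k-space-network output image,
+        recon_kspace, masked_kspace, and data_stats (the per-sample means / stds used
+        for normalization, kept so the outputs can be mapped back to image scale).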
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + output1, output2 = aux_image.detach(), LR_image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c2b86b6e2031b1af07ab7cb58efc182b7a2673 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/DuDo_mUnet_CatFusion.py @@ -0,0 +1,338 @@ +""" +Xiaohan Xing, 2023/11/10 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + multi layer concat fusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: 
+ output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_CATfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t2_gt: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
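+
+        Note: in practice this forward pass returns the tuple
+        (recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image,
+        recon_kspace, masked_kspace, data_stats); recon_kspace and
+        masked_kspace are None when args.domain == "image".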
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + t2_image = t2_gt * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + # # print("aux_image range:", aux_image.max(), aux_image.min()) + # print("LR_image range:", LR_image.max(), LR_image.min()) + # print("kspace recon_img range:", image.max(), image.min()) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + t2_gt_image = self.normalize(t2_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + t2_gt_image = torch.clamp(t2_gt_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + # output1, output2 = aux_image.detach(), LR_image.detach() + output1, output2 = aux_image.detach(), image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image, recon_kspace, masked_kspace, data_stats + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_CATfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Kspace_ConvNet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Kspace_ConvNet.py new file mode 100644 index 0000000000000000000000000000000000000000..918cbe598a62653394c8c4eaf99cd10a45c064ca --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Kspace_ConvNet.py @@ -0,0 +1,140 @@ +""" +2023/10/05 +Xiaohan Xing +Build a simple network with several convolution layers. 
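+The three parallel ConvBlocks (kernel sizes 3/5/7) act as multi-scale k-space
+interpolation filters; their concatenated outputs are re-weighted by a learned
+sigmoid map (freq_filter) before the final ConvBlock and data-consistency layer.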
+Input: concat (undersampled kspace of FS-PD, fully-sampled kspace of PD modality) +Output: interpolated kspace of FS-PD +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, args, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + # self.conv_blocks = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # self.conv_blocks.append(ConvBlock(self.chans, self.chans * 2, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans * 2, self.chans, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans, self.out_chans, drop_prob)) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + # print("output of the three conv blocks:", output1.shape, output2.shape, output3.shape) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("conv1_output range:", output1.max(), output1.min()) + # print("conv2_output range:", output2.max(), output2.min()) + # print("conv3_output range:", output3.max(), output3.min()) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. + spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + # print("output before DC layer:", output.shape) + # print("output range before DC layer:", output.max(), output.min()) + # output = output + kspace ## residual connection + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + # mask_output = output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
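+    (Note: in this k-space variant the LeakyReLU activations are commented out,
+    so each convolution is effectively followed only by instance normalization
+    and dropout.)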
+ """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +def build_model(args): + return mmConvKSpace(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Kspace_mUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Kspace_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2db9a357798be4abd4cfbd44e9207555770b0456 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Kspace_mUnet.py @@ -0,0 +1,202 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # print("output:", output.shape) + # print("kspace:", kspace.shape) + # print("mask:", mask.shape) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Kspace_mUnet_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Kspace_mUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ccfd28dc198ffeff2259a5fbed45dabdc76819 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Kspace_mUnet_new.py @@ -0,0 +1,200 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # print("using Unet without Relu layers") + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + # output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # # print("output:", output.shape) + # # print("kspace:", kspace.shape) + # # print("mask:", mask.shape) + # mask_output = mask * output + + return output # .permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/MINet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeeb063b12be1dc594e8e076e6a59e454fcfa0b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/MINet.py @@ -0,0 +1,356 @@ +""" +Compare with the method "Multi-contrast MRI Super-Resolution via a Multi-stage Integration Network". +The code is from https://github.com/chunmeifeng/MINet/edit/main/fastmri/models/MINet.py. 
+""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F + + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): 
+ def __init__(self,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 1 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + # print("modules_tail:", modules_tail) + # self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Sigmoid()) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + + + +class MINet(nn.Module): + def __init__(self, args, n_resgroups=3, n_resblocks=3, n_feats=64): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + # print("x1:", x1.shape, "x2:", x2.shape) + + ### 源代码中是super-resolution任务,所以self.tail对图像进行unsampling. 
我们做reconstruction把这步去掉 + x2 = self.tail(x2) + + + resT1 = x1 + resT2 = x2 + + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + # print("t1s:", [item.shape for item in t1s]) + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + # print("resT1 range:", resT1.max(), resT1.min()) + # print("resT2 range:", resT2.max(), resT2.min()) + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1, x2 #x1=pd x2=pdfs + + +def build_model(args): + return MINet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/MINet_common.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/MINet_common.py new file mode 100644 index 0000000000000000000000000000000000000000..631f3036db4edf8eab00d70c66d2058a3b65cd59 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/MINet_common.py @@ -0,0 +1,90 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
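+            # Upsample by repeated x2 steps: each step expands channels 4x with a
+            # conv and PixelShuffle(2) trades them for 2x spatial resolution.
+            # With scale == 1 the loop body never runs, so no upsampling is applied
+            # (the case used by the reconstruction models above).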
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + # print("upsampler:", m) + + super(Upsampler, self).__init__(*m) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/SANet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/SANet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfac3590e248c9532b002885c6c0b7a55ab1400a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/SANet.py @@ -0,0 +1,392 @@ +""" +This script contains the codes of "Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution (TNNLS 2023)" +https://github.com/chunmeifeng/SANet/blob/main/fastmri/models/SANet.py +""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import os + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N 
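+            Note: the map here is gamma * sigmoid(Conv3d(x.unsqueeze(1))),
+            applied as x * map + x; no N x N attention matrix is formed.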
+ """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,scale,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups #10 + self.n_resblocks = n_resblocks #20 + self.n_feats = n_feats #64 + kernel_size = 3 + reduction = 16 #16 + # scale = args.scale[0] + + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if 
name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' + .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class Seatt(nn.Module): + def __init__(self, in_c): + super(Seatt, self).__init__() + self.reduce = nn.Conv2d(in_c * 2, 32, 1) + self.ff_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.bf_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.rgbd_pred_layer = Pred_Layer(32 * 2) + self.convq = nn.Conv2d(32, 32, 3, 1, 1) + + def forward(self, rgb_feat, dep_feat, pred): + feat = torch.cat((rgb_feat, dep_feat), 1) + + # vis_attention_map_rgb_feat(rgb_feat[0][0]) + # vis_attention_map_dep_feat(dep_feat[0][0]) + # vis_attention_map_feat(feat[0][0]) + + + feat = self.reduce(feat) + [_, _, H, W] = feat.size() + pred = torch.sigmoid(pred) + + ni_pred = 1 - pred + + ff_feat = self.ff_conv(feat * pred) + bf_feat = self.bf_conv(feat * (1 - pred)) + new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) + + return new_pred + +class SANet(nn.Module): + def __init__(self, args, scale=1, n_resgroups=3, n_resblocks=3, n_feats=64): + super(SANet, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + + cs = [64,64,64,64] + self.Seatts = nn.ModuleList([Seatt(c) for c in cs]) + + self.net1 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + print("nlayer:",nlayer) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=3, padding=1) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2, Seatt, map_conv in zip(self.net1.body._modules.items(), 
self.net2.body._modules.items(), self.Seatts, self.map_convs): + + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + pred = map_conv(resT1) + res = Seatt(resT1,resT2,pred) + + resT2 = res+resT2 + + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + + return x1,x2 + + +def build_model(args): + return SANet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/SwinFuse_layer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/SwinFuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3dca8ba793d92961bc3f1d987e211fe542b303 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/SwinFuse_layer.py @@ -0,0 +1,1310 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. 
+ window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # print("x shape:", x.shape) + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + # print("num heads:", self.num_heads) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + 
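+        # Per window: qkv projection = N*C*3C, attention scores and weighted sum
+        # contribute N*N*C each, output projection = N*C*C, i.e. N*C*(4*C + 2*N)
+        # in total; e.g. for window_size = 7, N = 49.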
flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = 
relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + # print("x_windows:", x_windows.shape) + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
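+ Blocks inside the layer alternate between regular window attention (shift_size=0 for even block indices) and shifted-window attention (shift_size=window_size // 2 for odd ones).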
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + + # # 不使用残差连接 + # x = self.residual_group_A(x, x_size) + # y = self.residual_group_B(y, x_size) + + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion_layer(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
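+ This layer-level variant keeps only the cross-attention fusion stage of SwinFusion: forward(A, B) pads both inputs to a multiple of the window size, applies the CRSTB fusion layers, and returns the two fused feature maps cropped back to the input spatial size (there is no shallow feature extraction or reconstruction head).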
+ Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Fusion_num_heads=[6, 6], + window_size=7, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, img_range=1., resi_connection='1conv', + **kwargs): + super(SwinFusion_layer, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + 
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + + ### fusion layer. + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + self.norm_Fusion_B = norm_layer(self.num_features) + + # print("fusion layers:", len(self.layers_Fusion)) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + + def forward_features_Fusion(self, x, y): + + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + # x = torch.cat([x, y], 1) + # ## Downsample the feature in the channel dimension + # x = self.lrelu(self.conv_after_body_Fusion(x)) + # print("x range after fusion:", x.max(), x.min()) + # print("y 
range after fusion:", y.max(), y.min()) + # print("x:", x.shape) + + return x, y + + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = self.check_image_size(y) + + print("modality1 input:", x.max(), x.min()) + print("modality2 input:", y.max(), y.min()) + + # multi-modal feature fusion. + x, y = self.forward_features_Fusion(x, y) ### returns the fused features of the two modalities. + print("modality1 output:", x.max(), x.min()) + print("modality2 output:", y.max(), y.min()) + + return x[:, :, :H, :W], y[:, :, :H, :W] + + + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + # this fusion-only variant has no shallow extraction (layers_Ex_A/B), reconstruction (layers_Re) + # or upsample modules, so only the fusion layers are counted here + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + return flops + + + +if __name__ == '__main__': + width, height = 120, 120 + swinfuse = SwinFusion_layer(img_size=(width, height), window_size=8, depths=[6, 6, 6], embed_dim=60, num_heads=[6, 6, 6]) + # print("SwinFusion layer:", swinfuse) + + A = torch.randn((4, 60, width, height)) + B = torch.randn((4, 60, width, height)) + + x = swinfuse(A, B) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/SwinFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/SwinFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3e72c2a3e8859423f3544043e8b9feaec002d0e2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/SwinFusion.py @@ -0,0 +1,1468 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # print("qkv after reshape:", qkv.shape) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. 
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
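+
+    Example (illustrative sketch; dim, resolution and head count are assumed values, not taken from this repo):
+        layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=6, window_size=7)
+        tokens = torch.randn(1, 56 * 56, 96)      # (B, H*W, C) token sequence
+        out = layer(tokens, (56, 56))             # same shape when downsample is None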
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + # 不使用残差连接 + # print("input x shape:", x.shape) + x = self.residual_group_A(x, x_size) + y = self.residual_group_B(y, x_size) + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
+ + Args: + img_size (int | tuple(int)): Input image size. Default: 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Re_depths=[4], + Ex_num_heads=[6], Fusion_num_heads=[6, 6], Re_num_heads=[6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinFusion, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + print('in_chans: ', in_chans) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + #### shallow feature extraction network: modified to use two 3x3 convolutions per modality #### + self.conv_first1_A = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first1_B = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first2_A = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1) + self.conv_first2_B = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + self.Re_num_layers = len(Re_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm =
patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + dpr_Re = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Re_depths))] # stochastic depth decay rule + # build Residual Swin Transformer blocks (RSTB) + self.layers_Ex_A = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_A.append(layer) + self.norm_Ex_A = norm_layer(self.num_features) + + self.layers_Ex_B = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_B.append(layer) + self.norm_Ex_B = norm_layer(self.num_features) + + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + 
self.norm_Fusion_B = norm_layer(self.num_features) + + self.layers_Re = nn.ModuleList() + for i_layer in range(self.Re_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Re_depths[i_layer], + num_heads=Re_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Re[sum(Re_depths[:i_layer]):sum(Re_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Re.append(layer) + self.norm_Re = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
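+            # Layers for the 'nearest+conv' reconstruction path kept from SwinIR; note that
+            # forward_features_Re() below only uses conv_last1/2/3, so this branch stays unused
+            # unless a matching upsampling step is added to the forward pass.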
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last1 = nn.Conv2d(embed_dim, embed_dim_temp, 3, 1, 1) + self.conv_last2 = nn.Conv2d(embed_dim_temp, int(embed_dim_temp/2), 3, 1, 1) + self.conv_last3 = nn.Conv2d(int(embed_dim_temp/2), num_out_ch, 3, 1, 1) + + self.tanh = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features_Ex_A(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_A: + x = layer(x, x_size) + + x = self.norm_Ex_A(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_Ex_B(self, x): + x = self.lrelu(self.conv_first1_B(x)) + x = self.lrelu(self.conv_first2_B(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_B: + x = layer(x, x_size) + + x = self.norm_Ex_B(x) # B L C + x = self.patch_unembed(x, x_size) + return x + + def forward_features_Fusion(self, x, y): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + # print("x token:", x.shape, "y token:", y.shape) + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + x = torch.cat([x, y], 1) + ## Downsample the feature in the channel dimension + x = self.lrelu(self.conv_after_body_Fusion(x)) + + return x + + def forward_features_Re(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Re: + x = layer(x, x_size) + + x = self.norm_Re(x) # B L C + x = self.patch_unembed(x, x_size) + # Convolution + x = self.lrelu(self.conv_last1(x)) + x = self.lrelu(self.conv_last2(x)) + x = self.conv_last3(x) + return x + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y =
self.check_image_size(y) + + # self.mean_A = self.mean.type_as(x) + # self.mean_B = self.mean.type_as(y) + # self.mean = (self.mean_A + self.mean_B) / 2 + + # x = (x - self.mean_A) * self.img_range + # y = (y - self.mean_B) * self.img_range + + # Feedforward + x = self.forward_features_Ex_A(x) + y = self.forward_features_Ex_B(y) + # print("x before fusion:", x.shape) + # print("y before fusion:", y.shape) + x = self.forward_features_Fusion(x, y) + x = self.forward_features_Re(x) + x = self.tanh(x) + + # x = x / self.img_range + self.mean + # print("H:", H, "upscale:", self.upscale) + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='convs') + + +if __name__ == '__main__': + model = SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='convs') + # print(model) + + A = torch.randn((1, 1, 240, 240)) + B = torch.randn((1, 1, 240, 240)) + + x = model(A, B) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/TransFuse.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/TransFuse.py new file mode 100644 index 0000000000000000000000000000000000000000..56ca3d2b24f382603c876e4f6909363a9794ffa3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/TransFuse.py @@ -0,0 +1,155 @@ +import math +from collections import deque + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torchvision import models + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. 
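+
+    Despite the name, no causal mask is applied in this implementation; inputs and outputs
+    are token sequences of shape (B, T, n_embd), with n_embd divisible by n_head.
+    Illustrative sketch (values assumed):
+        attn = SelfAttention(n_embd=512, n_head=4, attn_pdrop=0.1, resid_pdrop=0.1)
+        out = attn(torch.randn(2, 400, 512))   # (B, T, C) -> (B, T, C)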
+ """ + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(n_embd, n_embd) + self.query = nn.Linear(n_embd, n_embd) + self.value = nn.Linear(n_embd, n_embd) + # regularization + self.attn_drop = nn.Dropout(attn_pdrop) + self.resid_drop = nn.Dropout(resid_pdrop) + # output projection + self.proj = nn.Linear(n_embd, n_embd) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop): + super().__init__() + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.mlp = nn.Sequential( + nn.Linear(n_embd, block_exp * n_embd), + nn.ReLU(True), # changed from GELU + nn.Linear(block_exp * n_embd, n_embd), + nn.Dropout(resid_pdrop), + ) + + def forward(self, x): + B, T, C = x.size() + + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + + return x + + +class TransFuse_layer(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, n_embd, n_head, block_exp, n_layer, + num_anchors, seq_len=1, + embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1): + super().__init__() + self.n_embd = n_embd + self.seq_len = seq_len + self.vert_anchors = num_anchors + self.horz_anchors = num_anchors + + # positional embedding parameter (learnable), image + lidar + self.pos_emb = nn.Parameter(torch.zeros(1, 2 * seq_len * self.vert_anchors * self.horz_anchors, n_embd)) + + self.drop = nn.Dropout(embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(n_embd, n_head, + block_exp, attn_pdrop, resid_pdrop) + for layer in range(n_layer)]) + + # decoder head + self.ln_f = nn.LayerNorm(n_embd) + + self.block_size = seq_len + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + + def forward(self, m1, m2): + """ + Args: + m1 (tensor): B*seq_len, C, H, W + m2 (tensor): B*seq_len, C, H, W + """ + + bz = m2.shape[0] // self.seq_len + h, w = m2.shape[2:4] + + # forward the image model for token embeddings + m1 = m1.view(bz, self.seq_len, -1, h, w) + m2 = m2.view(bz, self.seq_len, -1, h, w) + + # pad token embeddings along number of tokens dimension + token_embeddings = torch.cat([m1, m2], 
dim=1).permute(0,1,3,4,2).contiguous() + token_embeddings = token_embeddings.view(bz, -1, self.n_embd) # (B, an * T, C) + + # add (learnable) positional embedding for all tokens + x = self.drop(self.pos_emb + token_embeddings) # (B, an * T, C) + x = self.blocks(x) # (B, an * T, C) + x = self.ln_f(x) # (B, an * T, C) + x = x.view(bz, 2 * self.seq_len, self.vert_anchors, self.horz_anchors, self.n_embd) + x = x.permute(0,1,4,2,3).contiguous() # same as token_embeddings + + m1_out = x[:, :self.seq_len, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + m2_out = x[:, self.seq_len:, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w) + # print("modality1 output:", m1_out.max(), m1_out.min()) + # print("modality2 output:", m2_out.max(), m2_out.min()) + + return m1_out, m2_out + + +if __name__ == "__main__": + + feature1 = torch.randn((4, 512, 20, 20)) + feature2 = torch.randn((4, 512, 20, 20)) + model = TransFuse_layer(n_embd=512, n_head=4, block_exp=4, n_layer=8, num_anchors=20, seq_len=1) + print("TransFuse_layer:", model) + feat1, feat2 = model(feature1, feature2) + print(feat1.shape, feat2.shape) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Unet_ART.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Unet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b0f834270a921ae4eb428c92b3e782fbed19b3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/Unet_ART.py @@ -0,0 +1,243 @@ +""" +2023/07/06: attach an ART transformer block after the features of every U-Net level to extract long-range dependent features. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + input_dim: Number of channels in the input to the U-Net model. + output_dim: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability.
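+
+        Note: the ART attention stages below use the hard-coded `self.img_size = 240` and assume
+        square inputs; an illustrative call is Unet_ART(args)(torch.randn(1, 1, 240, 240)).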
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + output = conv_layer(output) + + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..334f1821a9b59e692641f70cdc61e7655b1a9c89 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/__init__.py @@ -0,0 +1,114 @@ +from .unet import build_model as UNET +from .mmunet import build_model as MUNET +from .humus_net import build_model as HUMUS +from .MINet import build_model as MINET +from .SANet import build_model as SANET +from .mtrans_net import build_model as MTRANS +from .unimodal_transformer import build_model as TRANS +from .cnn_transformer_new import build_model as CNN_TRANS +from .cnn import build_model as CNN +from .unet_transformer import build_model as UNET_TRANSFORMER +from .swin_transformer import build_model as SWIN_TRANS +from .restormer import build_model as RESTORMER + + +from .cnn_late_fusion import build_model as MULTI_CNN +from .mmunet_late_fusion import build_model as MMUNET_LATE +from .mmunet_early_fusion import build_model as MMUNET_EARLY +from .trans_unet.trans_unet import build_model as TRANSUNET +from .SwinFusion import build_model as SWINMULTI +from .mUnet_transformer import build_model as MUNET_TRANSFORMER +from .munet_multi_transfuse import build_model as MUNET_TRANSFUSE +from .munet_swinfusion import build_model as MUNET_SWINFUSE +from .munet_multi_concat import build_model as MUNET_CONCAT +from .munet_concat_decomp import build_model as MUNET_CONCAT_DECOMP +from .munet_multi_sum import build_model as MUNET_SUM +# from .mUnet_ARTfusion_new import build_model as MUNET_ART_FUSION +from .mUnet_ARTfusion import build_model as MUNET_ART_FUSION +from .mUnet_Restormer_fusion import build_model as MUNET_RESTORMER_FUSION +from .mUnet_ARTfusion_SeqConcat import build_model as MUNET_ART_FUSION_SEQ + +from .ARTUnet import build_model as ART +from .Unet_ART import build_model as UNET_ART +from .unet_restormer import build_model as UNET_RESTORMER +from .mmunet_ART import build_model as MMUNET_ART +from .mmunet_restormer import build_model as MMUNET_RESTORMER +from .mmunet_restormer_v2 import build_model as MMUNET_RESTORMER_V2 +from .mmunet_restormer_3blocks import build_model as MMUNET_RESTORMER_SMALL + +from .mARTUnet import build_model as ART_MULTI_INPUT + +from .ART_Restormer import build_model as ART_RESTORMER +from .ART_Restormer_v2 import build_model as ART_RESTORMER_V2 +from .swinIR import build_model as SWINIR + +from .Kspace_ConvNet import build_model as KSPACE_CONVNET +from .Kspace_mUnet_new import build_model as KSPACE_MUNET +from .kspace_mUnet_concat import build_model as KSPACE_MUNET_CONCAT +from .kspace_mUnet_AttnFusion import build_model as KSPACE_MUNET_ATTNFUSE + +from .DuDo_mUnet import build_model as DUDO_MUNET +from .DuDo_mUnet_ARTfusion import 
build_model as DUDO_MUNET_ARTFUSION +from .DuDo_mUnet_CatFusion import build_model as DUDO_MUNET_CONCAT + + +model_factory = { + 'unet_single': UNET, + 'humus_single': HUMUS, + 'transformer_single': TRANS, + 'cnn_transformer': CNN_TRANS, + 'cnn_single': CNN, + 'swin_trans_single': SWIN_TRANS, + 'trans_unet': TRANSUNET, + 'unet_transformer': UNET_TRANSFORMER, + 'restormer': RESTORMER, + 'unet_art': UNET_ART, + 'unet_restormer': UNET_RESTORMER, + 'art': ART, + 'art_restormer': ART_RESTORMER, + 'art_restormer_v2': ART_RESTORMER_V2, + 'swinIR': SWINIR, + + + 'munet_transformer': MUNET_TRANSFORMER, + 'munet_transfuse': MUNET_TRANSFUSE, + 'cnn_late_multi': MULTI_CNN, + 'unet_multi': MUNET, + 'unet_late_multi':MMUNET_LATE, + 'unet_early_multi':MMUNET_EARLY, + 'munet_ARTfusion': MUNET_ART_FUSION, + 'munet_restormer_fusion': MUNET_RESTORMER_FUSION, + + + 'minet_multi': MINET, + 'sanet_multi': SANET, + 'mtrans_multi': MTRANS, + 'swin_fusion': SWINMULTI, + 'munet_swinfuse': MUNET_SWINFUSE, + 'munet_concat': MUNET_CONCAT, + 'munet_concat_decomp': MUNET_CONCAT_DECOMP, + 'munet_sum': MUNET_SUM, + 'mmunet_art': MMUNET_ART, + 'mmunet_restormer': MMUNET_RESTORMER, + 'mmunet_restormer_small': MMUNET_RESTORMER_SMALL, + 'mmunet_restormer_v2': MMUNET_RESTORMER_V2, + + 'art_multi_input': ART_MULTI_INPUT, + + 'munet_ARTfusion_SeqConcat': MUNET_ART_FUSION_SEQ, + + 'kspace_ConvNet': KSPACE_CONVNET, + 'kspace_munet': KSPACE_MUNET, + 'kspace_munet_concat': KSPACE_MUNET_CONCAT, + 'kspace_munet_AttnFusion': KSPACE_MUNET_ATTNFUSE, + + 'dudo_munet': DUDO_MUNET, + 'dudo_munet_ARTfusion': DUDO_MUNET_ARTFUSION, + 'dudo_munet_concat': DUDO_MUNET_CONCAT, +} + + +def build_model_from_name(args): + assert args.model_name in model_factory.keys(), 'unknown model name' + + return model_factory[args.model_name](args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c130178e7cdabaaef8686da0cec341265f3c43d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn.py @@ -0,0 +1,246 @@ +""" +把our method中的单模态CNN_Transformer中的transformer换成几个conv_block用来下采样提取特征, 看是否能够达到和Unet相似的效果。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Reconstructor(nn.Module): + def __init__(self): + super(CNN_Reconstructor, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + nlatent = self.encoder.output_dim + self.decoder = Decoder(n_upsample=4, n_res=1, dim=nlatent, output_dim=1, pad_type='zero') + + + def forward(self, m): + #4,1,240,240 + feature_output = self.encoder.model(m) # [4, 256, 60, 60] + # print("output of the encoder:", feature_output.shape) + + output = self.decoder(feature_output) + + return output + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = 
nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + + def 
forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + + +def build_model(args): + return CNN_Reconstructor() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn_late_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..58dacad6910b6e896d00a0d36c9ea60e8d77217f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn_late_fusion.py @@ -0,0 +1,196 @@ +""" +在lequan的代码上,去掉transformer-based feature fusion部分。 +直接用CNN提取特征之后concat作为multi-modal representation. 用普通的decoder进行图像重建。 +把T1作为guidance modality, T2作为target modality. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
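+
+    Late-fusion variant: the guidance and target modalities are encoded by two
+    separate down-sampling paths, their bottleneck features are concatenated, and
+    a single decoder reconstructs the target image (T1 guides the T2 reconstruction).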
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers1 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + self.down_sample_layers2 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers1.append(ConvBlock(ch, ch * 2, drop_prob)) + self.down_sample_layers2.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + + # self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # ch = chans + # for _ in range(num_pool_layers - 1): + # self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + # ch *= 2 + # self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv = nn.Sequential(nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), nn.Tanh()) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = torch.cat([image, aux_image], 1) + # for layer in self.down_sample_layers: + # output = layer(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + + # apply down-sampling layers + output = image + for layer in self.down_sample_layers1: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + output_m1 = output + + output = aux_image + for layer in self.down_sample_layers2: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + output_m2 = output + + # print("two modality outputs:", output_m1.shape, output.shape) + output = torch.cat([output_m1, output_m2], 1) + + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv in self.up_transpose_conv: + output = transpose_conv(output) + + output = self.up_conv(output) + # print("model output:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..261d4d214e35c3a9be5d28bc426ce7280124723c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn_transformer.py @@ -0,0 +1,289 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 
送到transformer网络中进行MRI重建。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=2, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=2, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 60 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 4 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
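+        # L2-normalize each patch token along the channel dimension, then prepend the
+        # learnable class token and add positional embeddings before the transformer.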
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + feature_output = self.upsampling1(feature_output) # [4,512,30,30] + feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn_transformer_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn_transformer_new.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b2dea1e7cc8d456a684b6c839f052383999e84 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/cnn_transformer_new.py @@ -0,0 +1,290 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 送到transformer网络中进行MRI重建。 +2023/05/23: 之前是从尺寸为60*60的特征图上切patch, 现在可以尝试直接用conv_blocks提取15*15的特征图,然后按照patch_size = 1切成225个patches。 +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## sum two modality features +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=4, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + patch_size = 1 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
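+        # Note: with n_downsample=4 the encoder output is [B, 1024, 15, 15], and
+        # patch_size=1 yields 15*15=225 single-pixel patches; the [4, 256, 60, 60]
+        # shape comment above is carried over from the 60x60 / patch_size=4 variant.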
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + # feature_output = self.upsampling1(feature_output) # [4,512,30,30] + # feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/humus_net.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/humus_net.py new file mode 100644 index 0000000000000000000000000000000000000000..d23701634f849b414866d71b3c23c6d07f3d6437 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/humus_net.py @@ -0,0 +1,1070 @@ +""" +HUMUS-Block +Hybrid Transformer-convolutional Multi-scale denoiser. + +We use parts of the code from +SwinIR (https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py) +Swin-Unet (https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.py) + +HUMUS-Net中是对kspace data操作, 多个unrolling cascade, 每个cascade中都利用HUMUS-block对图像进行denoising. +因为我们的方法输入是Under-sampled images, 所以不需要进行unrolling, 直接使用一个HUMUS-block进行MRI重建。 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # print("x shape:", x.shape) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. 
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + 
# reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + +class PatchExpandSkip(nn.Module): + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) + self.channel_mix = nn.Linear(dim, dim // 2, bias=False) + self.norm = norm_layer(dim // 2) + + def forward(self, x, skip): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x = torch.cat([x, skip], dim=-1) + x = self.channel_mix(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, torch.tensor(x_size)) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
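'1conv'/'3conv'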
+ block_type: + D: downsampling block, + U: upsampling block + B: bottleneck block + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', block_type='B'): + super(RSTB, self).__init__() + self.dim = dim + conv_dim = dim // (patch_size ** 2) + divide_out_ch = 1 + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(conv_dim, conv_dim // divide_out_ch, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(conv_dim, conv_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // divide_out_ch, 3, 1, 1)) + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.block_type = block_type + if block_type == 'B': + self.reshape = torch.nn.Identity() + elif block_type == 'D': + self.reshape = PatchMerging(input_resolution, dim) + elif block_type == 'U': + self.reshape = PatchExpandSkip([res // 2 for res in input_resolution], dim * 2) + else: + raise ValueError('Unknown RSTB block type.') + + def forward(self, x, x_size, skip=None): + if self.block_type == 'U': + assert skip is not None, "Skip connection is required for patch expand" + x = self.reshape(x, skip) + + out = self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) + block_out = self.patch_embed(out) + x + + if self.block_type == 'D': + block_out = (self.reshape(block_out), block_out) # return skip connection + + return block_out + +class PatchEmbed(nn.Module): + r""" Fixed Image to Patch Embedding. Flattens image along spatial dimensions + without learned projection mapping. Only supports 1x1 patches. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + embed_dim (int): Number of output channels. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + +class PatchEmbedLearned(nn.Module): + r""" Image to Patch Embedding with arbitrary patch size and learned linear projection. 
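+    Each patch is flattened to a vector of length in_chans * patch_size[0] * patch_size[1]
+    and projected to embed_dim with a learned nn.Linear (unlike the fixed PatchEmbed above).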
+ Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_chans * patch_size[0] * patch_size[1], embed_dim) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.patch_to_vec(x) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + def patch_to_vec(self, x): + """ + Args: + x: (B, C, H, W) + Returns: + patches: (B, num_patches, patch_size * patch_size * C) + """ + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(B, H // self.patch_size[0], self.patch_size[0], W // self.patch_size[1], self.patch_size[1], C) + patches = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.patch_size[0], self.patch_size[1], C) + patches = patches.contiguous().view(B, self.num_patches, self.patch_size[0] * self.patch_size[1] * C) + return patches + + +class PatchUnEmbed(nn.Module): + r""" Fixed Image to Patch Unembedding via reshaping. Only supports patch size 1x1. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + +class PatchUnEmbedLearned(nn.Module): + r""" Patch Embedding to Image with learned linear projection. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. 
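+        out_chans (int): Number of output image channels (stored but not used in forward).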
+ """ + + def __init__(self, img_size, patch_size, in_chans=None, out_chans=None, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.out_chans = out_chans + + self.proj_up = nn.Linear(in_chans, in_chans * patch_size[0] * patch_size[1]) + + def forward(self, x, x_size=None): + B, HW, C = x.shape + x = self.proj_up(x) + x = self.vec_to_patch(x) + return x + + def vec_to_patch(self, x): + B, HW, C = x.shape + x = x.view(B, self.patches_resolution[0], self.patches_resolution[1], C).contiguous() + x = x.view(B, self.patches_resolution[0], self.patch_size[0], self.patches_resolution[1], self.patch_size[1], C // (self.patch_size[0] * self.patch_size[1])).contiguous() + x = x.view(B, self.img_size[0], self.img_size[1], -1).contiguous().permute(0, 3, 1, 2) + return x + + +class HUMUSNet(nn.Module): + r""" HUMUS-Block + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depths (tuple(int)): Depth of each Swin Transformer layer in encoder and decoder paths. + num_heads (tuple(int)): Number of attention heads in different layers of encoder and decoder. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + use_checkpoint (bool): Whether to use checkpointing to save memory. + img_range: Image range. 1. or 255. + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + conv_downsample_first: use convolutional downsampling before MUST to reduce compute load on Transformers + """ + + def __init__(self, args, + img_size=[240, 240], + in_chans=1, + patch_size=1, + embed_dim=66, + depths=[2, 2, 2], + num_heads=[3, 6, 12], + window_size=5, + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + img_range=1., + resi_connection='1conv', + bottleneck_depth=2, + bottleneck_heads=24, + conv_downsample_first=True, + out_chans=1, + no_residual_learning=False, + **kwargs): + super(HUMUSNet, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans if out_chans is None else out_chans + + self.center_slice_out = (out_chans == 1) + + self.img_range = img_range + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + self.conv_downsample_first = conv_downsample_first + self.no_residual_learning = no_residual_learning + + ##################################################################################################### + ################################### 1, input block ################################### + input_conv_dim = embed_dim + self.conv_first = nn.Conv2d(num_in_ch, input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** self.num_layers) + self.mlp_ratio = mlp_ratio + + # Downsample for low-res feature extraction + if self.conv_downsample_first: + img_size = [im //2 for im in img_size] + self.conv_down_block = ConvBlock(input_conv_dim // 2, input_conv_dim, 0.0) + self.conv_down = DownsampConvBlock(input_conv_dim, input_conv_dim) + + # split image into non-overlapping patches + if patch_size > 1: + self.patch_embed = PatchEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + if patch_size > 1: + self.patch_unembed = PatchUnEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, out_chans=embed_dim, norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build MUST + # encoder + self.layers_down = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** i_layer) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + 
patches_resolution[1] // dim_scaler), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='D', + ) + self.layers_down.append(layer) + + # bottleneck + dim_scaler = (2 ** self.num_layers) + self.layer_bottleneck = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=bottleneck_depth, + num_heads=bottleneck_heads, + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='B', + ) + + # decoder + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** (self.num_layers - i_layer - 1)) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='U', + ) + self.layers_up.append(layer) + + self.norm_down = norm_layer(self.num_features) + self.norm_up = norm_layer(self.embed_dim) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(input_conv_dim, input_conv_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(input_conv_dim, input_conv_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim, 3, 1, 1)) + + # Upsample if needed + if self.conv_downsample_first: + self.conv_up_block = ConvBlock(input_conv_dim, input_conv_dim // 2, 0.0) + self.conv_up = TransposeConvBlock(input_conv_dim, input_conv_dim // 2) + + ##################################################################################################### + ################################ 3, output block ################################ + self.conv_last = nn.Conv2d(input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, num_out_ch, 3, 1, 1) + self.output_norm = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + 
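# The position embeddings and relative position bias tables returned by the two
+    # methods below can be excluded from weight decay by optimizer setups that check
+    # no_weight_decay / no_weight_decay_keywords (a timm-style convention).
+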
@torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + + # divisible by window size + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + + # divisible by total downsampling ratio + # this could be done more efficiently by combining the two + total_downsamp = int(2 ** (self.num_layers - 1)) + pad_h = (total_downsamp - h % total_downsamp) % total_downsamp + pad_w = (total_downsamp - w % total_downsamp) % total_downsamp + x = F.pad(x, (0, pad_w, 0, pad_h)) + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # encode + skip_cons = [] + for i, layer in enumerate(self.layers_down): + x, skip = layer(x, x_size=layer.input_resolution) + skip_cons.append(skip) + x = self.norm_down(x) # B L C + + # bottleneck + x = self.layer_bottleneck(x, self.layer_bottleneck.input_resolution) + + # decode + for i, layer in enumerate(self.layers_up): + x = layer(x, x_size=layer.input_resolution, skip=skip_cons[-i-1]) + x = self.norm_up(x) + x = self.patch_unembed(x, x_size) + return x + + def forward(self, x): + # print("input shape:", x.shape) + # print("input range:", x.max(), x.min()) + C, H, W = x.shape[1:] + center_slice = (C - 1) // 2 + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.conv_downsample_first: + x_first = self.conv_first(x) + x_down = self.conv_down(self.conv_down_block(x_first)) + res = self.conv_after_body(self.forward_features(x_down)) + res = self.conv_up(res) + res = torch.cat([res, x_first], dim=1) + res = self.conv_up_block(res) + + res = self.conv_last(res) + + if self.no_residual_learning: + x = res + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + res + else: + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + + if self.no_residual_learning: + x = self.conv_last(res) + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + self.conv_last(res) + + # print("output range before scaling:", x.max(), x.min()) + # print("self.mean:", self.mean) + + if self.center_slice_out: + x = x / self.img_range + self.mean[:, center_slice, ...].unsqueeze(1) + else: + x = x / self.img_range + self.mean + + x = self.output_norm(x) + + return x[:, :, :H, :W] + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
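+
+        Illustrative usage sketch (assumes a 4-channel input, batch size 1,
+        and zero dropout; shows the shape behaviour only):
+            >>> import torch
+            >>> block = ConvBlock(in_chans=4, out_chans=8, drop_prob=0.0)
+            >>> block(torch.randn(1, 4, 64, 64)).shape
+            torch.Size([1, 8, 64, 64])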
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + +class DownsampConvBlock(nn.Module): + """ + A Downsampling Convolutional Block that consists of one strided convolution + layer followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.Conv2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H/2, W/2)`. 
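+
+        Illustrative usage sketch (assumes an even-sized input so the stride-2
+        convolution halves H and W exactly):
+            >>> import torch
+            >>> down = DownsampConvBlock(in_chans=8, out_chans=16)
+            >>> down(torch.randn(1, 8, 64, 64)).shape
+            torch.Size([1, 16, 32, 32])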
+ """ + return self.layers(image) + + + +def build_model(args): + return HUMUSNet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..02588c24559c0604ddbbd03a14b3b32e6cff2c8d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/kspace_mUnet_AttnFusion.py @@ -0,0 +1,300 @@ +""" +Xiaohan Xing, 2023/11/22 +两个模态分别用CNN提取多个层级的特征, 每个层级都把特征沿着宽度拼接起来,得到(h, 2w, c)的特征。 +然后学习得到(h, 2w)的attention map, 和原始的特征相乘送到后面的层。 +""" + +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. 
+ drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.attn_layers = nn.ModuleList() + + for l in range(self.num_pool_layers): + + self.attn_layers.append(nn.Sequential(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob), + nn.Conv2d(chans*(2**l), 2, kernel_size=1, stride=1), + nn.Sigmoid())) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, attn_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.attn_layers): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat之后求attention, 然后进行modulation. 
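+            ### (Translation: concatenate the features of the two modalities, predict an
+            ### attention map from the concatenation, then modulate each stream below by
+            ### (1 + its attention channel).)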
+ attn = attn_layer(torch.cat((output1, output2), 1)) + # print("spatial attention:", attn[:, 0, :, :].unsqueeze(1).shape) + # print("output1:", output1.shape) + output1 = output1 * (1 + attn[:, 0, :, :].unsqueeze(1)) + output2 = output2 * (1 + attn[:, 1, :, :].unsqueeze(1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
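+
+        Illustrative usage sketch (the stride-2 transpose convolution restores the
+        resolution halved by the encoder's F.avg_pool2d(kernel_size=2, stride=2)):
+            >>> import torch
+            >>> up = TransposeConvBlock(in_chans=64, out_chans=32)
+            >>> up(torch.randn(1, 64, 16, 16)).shape
+            torch.Size([1, 32, 32, 32])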
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/kspace_mUnet_concat.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/kspace_mUnet_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..d16bfdc83305d3e1f490e5ca4b445c6d730557ce --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/kspace_mUnet_concat.py @@ -0,0 +1,351 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Predict the low-frequency mask ############## +class Mask_Predictor(nn.Module): + def __init__(self, in_chans, out_chans, chans, drop_prob): + super(Mask_Predictor, self).__init__() + self.conv1 = ConvBlock(in_chans, chans, drop_prob) + self.conv2 = ConvBlock(chans, chans*2, drop_prob) + self.conv3 = ConvBlock(chans*2, chans*4, drop_prob) + + # self.conv4 = ConvBlock(chans*4, 1, drop_prob) + + self.FC = nn.Linear(chans*4, out_chans) + self.act = nn.Sigmoid() + + + def forward(self, x): + ### 三层卷积和down-sampling. + # print("[Mask predictor] input data range:", x.max(), x.min()) + output = F.avg_pool2d(self.conv1(x), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer1_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv2(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer2_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv3(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer3_feature range:", output.max(), output.min()) + + feature = F.adaptive_avg_pool2d(F.relu(output), (1, 1)).squeeze(-1).squeeze(-1) + # print("mask_prediction feature:", output[:, :5]) + # print("[Mask predictor] bottleneck feature range:", feature.max(), feature.min()) + + output = self.FC(feature) + # print("output before act:", output) + output = self.act(output) + # print("mask output:", output) + + return output + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, 
out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.mask_predictor1 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + self.mask_predictor2 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
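+
+        Illustrative usage sketch (not part of the original code; assumes
+        args.HF_refine == "False" so the recon inputs may be None, default
+        constructor arguments, and 64x64 inputs). The forward pass returns a
+        6-tuple (two reconstructions, two mask coordinates, two DC weights),
+        not a single tensor:
+            >>> import torch
+            >>> from types import SimpleNamespace
+            >>> net = mUnet_multi_fuse(SimpleNamespace(HF_refine="False"))
+            >>> k1 = torch.randn(1, 2, 64, 64)
+            >>> k2 = torch.randn(1, 2, 64, 64)
+            >>> out1, out2, coords1, coords2, dcw1, dcw2 = net(k1, k2, None, None)
+            >>> out1.shape, coords1.shape, dcw1.shape
+            (torch.Size([1, 2, 64, 64]), torch.Size([1, 2]), torch.Size([1]))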
+ """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + ### 预测做low-freq DC的mask区域. + ### 输出三个维度, 前两维是坐标,最后一维是mask内部的权重 + t1_mask = self.mask_predictor1(output1) + t2_mask = self.mask_predictor2(output2) + + # print("t1_mask and t2_mask:", t1_mask.shape, t2_mask.shape) + t1_mask_coords, t2_mask_coords = t1_mask[:, :2], t2_mask[:, :2] + t1_DC_weight, t2_DC_weight = t1_mask[:, -1], t2_mask[:, -1] + + + # ### 预测做low-freq DC的DC weight + # t1_DC_weight = self.mask_predictor1(output1) + # t2_DC_weight = self.mask_predictor2(output2) + + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight ## + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
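+
+    Note: in this k-space variant the LeakyReLU calls are left commented out below,
+    so each stage is effectively Conv2d -> InstanceNorm2d -> Dropout2d.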
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mARTUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9607d41ca14562a94e542a1804af5aeaf15bc123 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mARTUnet.py @@ -0,0 +1,563 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +2023/07/20, Xiaohan Xing +将两个模态的图像在输入端concat, 得到num_channels = 2的input, 然后送到ART模型进行T2 modality的重建。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # print("feature after layer_norm:", x.max(), x.min()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", 
attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=2, + out_channels=1, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img, aux_img): + stack = [] + bs, _, H, W = inp_img.shape + fuse_img = torch.cat((inp_img, aux_img), 1) + inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + 
self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=2, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..610b1e1e38d74f695a3a19f22a6a80f3152bf713 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import CS_TransformerBlock as TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + 
return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + # output1, output2 = image1, image2 + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
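+            ### (Translation: the per-level encoder features, taken before multi-modal
+            ### fusion, are passed to the decoder through skip connections.)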
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py new file mode 100644 index 0000000000000000000000000000000000000000..55280925415f1db47cf746b59fb9710d47c9bff2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_ARTfusion_SeqConcat.py @@ -0,0 +1,374 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any + +import matplotlib.pyplot as plt +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet_new import ConcatTransformerBlock, TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + 
output = conv(output) + + return output + + + + +class mUnet_ARTfusion_SeqConcat(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4. + ): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.vis_feature_maps = False + self.vis_attention_maps = False + # self.vis_feature_maps = args.MODEL.VIS_FEAT_MAPS + # self.vis_attention_maps = args.MODEL.VIS_ATTN_MAPS + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + if l < 2: + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * (2 ** (l + 1))), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + else: + self.fuse_layers.append(nn.ModuleList([ConcatTransformerBlock(dim=int(chans * (2 ** l)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + feature_maps = { + 'modal1': {'before': [], 'after': []}, + 'modal2': {'before': [], 'after': []} + } + + attention_maps = [] + """ + attention_maps = [ + unet layer 1 attention maps (x2): [map from TransBlock1, TransBlock2, ...], + unet layer 2 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 3 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 4 attention maps (x6): [map from TransBlock1, TransBlock2, ...], + ] + """ + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
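+            # Fusion routing in this sequential-concat variant (see the `l < 2` branches below):
+            # the first two encoder levels concatenate the two modality features and run joint
+            # TransformerBlocks on the fused token sequence, while deeper levels keep the two
+            # streams separate and exchange information through ConcatTransformerBlocks.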
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + if self.vis_feature_maps: + feature_maps['modal1']['before'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['before'].append(output2.detach().cpu().numpy()) + + ### 两个模态的特征concat + # output = torch.cat((output1, output2), 1) + + + + unet_layer_attn_maps = [] + """ + unet_layer_attn_maps: [ + attn_map from TransformerBlock1: (B, nHGroups, nWGroups, winSize, winSize), + attn_map from TransformerBlock2: (B, nHGroups, nWGroups, winSize, winSize), + ... + ] + """ + if l < 2: + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + else: + output1 = rearrange(output1, "b c h w -> b (h w) c").contiguous() + output2 = rearrange(output2, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + if l < 2: + if self.vis_attention_maps: + output, attn_map = layer(output, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output = layer(output, [H // (2**l), W // (2**l)]) + + else: + if self.vis_attention_maps: + output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + attention_maps.append(unet_layer_attn_maps) + + if l < 2: + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + else: + output1 = rearrange(output1, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output2 = rearrange(output2, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + # if self.vis_attention_maps: + # output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) # attn_map: (B, nHGroups, nWGroups, winSize, winSize) + # unet_layer_attn_maps.append(attn_map) + # else: + # output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + # attention_maps.append(unet_layer_attn_maps) + + + + if self.vis_feature_maps: + feature_maps['modal1']['after'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['after'].append(output2.detach().cpu().numpy()) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + if self.vis_feature_maps: + return output1, output2, feature_maps#, relation_stack1, relation_stack2 #, t1_features, t2_features + elif self.vis_attention_maps: + return output1, output2, attention_maps + else: + return output1, output2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion_SeqConcat(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8ab29b641fa0231ad0d163224a4a90b44c32a9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_Restormer_fusion.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .restormer import TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8]): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.Sequential(*[TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[0], ffn_expansion_factor=2.66, \ + bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + ### 将encoder中multi-modal fusion之后的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = fuse_layer(output) + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2a209758203b198502154b75496333ef91717a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mUnet_transformer.py @@ -0,0 +1,387 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+ +Xiaohan Xing, 2023/06/08 +分别用CNN提取两个模态的特征,然后送到transformer进行特征交互, 将经过transformer提取的特征和之前CNN提取的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
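+            # the matching ConvBlock appended below takes 2*ch channels because the decoder
+            # feature is concatenated with the encoder skip feature before it is applied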
self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Decoder_wo_skip(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder_wo_skip, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + output = transpose_conv(output) + output = conv(output) + + return output + + + + + +class mUnet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch*2 + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1, output2 = image1, image2 + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack1, output1 = self.encoder1(output1) ### output size: [4, 256, 15, 15] + stack2, output2 = self.encoder2(output2) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output1) # [4, 225, 256], 225 is the number of patches. + patch_embed_input1 = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + n_embed = self.to_patch_embedding(output2) # [4, 225, 256], 225 is the number of patches. 
+ patch_embed_input2 = F.normalize(n_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input1, patch_embed_input2), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1]/2)), int(np.sqrt(feature_output.shape[1]/2)) + patch_num = feature_output.shape[1] + feature_output1 = feature_output[:, :patch_num//2, :] + feature_output2 = feature_output[:, patch_num//2:, :] + feature_output1 = feature_output1.contiguous().view(b, feature_output1.shape[-1], h, w) # [4,c,15,15] + feature_output2 = feature_output2.contiguous().view(b, feature_output2.shape[-1], h, w) + + output = torch.cat((output1, output2), 1) + torch.cat((feature_output1, feature_output2), 1) + # print("CNN feature range:", output1.max(), output1.min()) + # print("Transformer feature range:", feature_output1.max(), feature_output1.min()) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_Transformer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee0c4784d04638df4ca259613777dde34980a2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/04 +Concatenate the input images from different modalities and input it into the UNet. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_ART.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..1a09f2900a8924cdc1a7c373270e6ff8aadcfa45 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_ART.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_ART_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_ART_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe6bc0d9192a126f39932b40443badb5b60068 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_ART_v2.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_early_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_early_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d062f057a2080a471005125a00ee27453a4b0706 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_early_fusion.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +encapsulate the encoder and decoder for the Unet model. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + stack, output = self.encoder(output) + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + output = self.decoder(output, stack) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_late_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b48e7d4445afd6a5bc1eaa2b065c37431585e94 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_late_fusion.py @@ -0,0 +1,229 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +use Unet for the reconstruction of each modality, concat the intermediate features of both modalities. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("decoder output:", output.shape) + + return output + + + + +class mmUnet_late(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack1, output1 = self.encoder1(image) + stack2, output2 = self.encoder2(aux_image) + ### fuse two modalities to get multi-modal representation. + output = torch.cat([output1, output2], 1) + output = self.conv(output) ### from [4, 512, 15, 15] to [4, 512, 15, 15] + + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet_late(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..88c7920d65c58ab82a00309a1ae501a8d5b33804 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_restormer.py @@ -0,0 +1,237 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + 
self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
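# --- Editor's illustration (not part of the patch): the TransformerBlock used above is imported
# from .restormer_block and is not shown in this diff. The sketch below only illustrates the
# general idea of Restormer-style "transposed" (channel-wise) self-attention, i.e. the
# "channel-wise long-range dependency" the file header refers to; it is NOT the repository's
# implementation and its parameter names are assumptions.
import torch
from torch import nn
from torch.nn import functional as F

class ChannelAttentionSketch(nn.Module):
    def __init__(self, dim: int, num_heads: int, bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=1)
        # flatten spatial positions; each head attends over its slice of channels,
        # so the attention map is (c/heads) x (c/heads) instead of (h*w) x (h*w)
        q = q.reshape(b, self.num_heads, c // self.num_heads, h * w)
        k = k.reshape(b, self.num_heads, c // self.num_heads, h * w)
        v = v.reshape(b, self.num_heads, c // self.num_heads, h * w)
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        attn = ((q @ k.transpose(-2, -1)) * self.temperature).softmax(dim=-1)
        return self.project_out((attn @ v).reshape(b, c, h, w))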
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e3afe68612727d3195872a121ec8040fb0f66ef2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_restormer_3blocks.py @@ -0,0 +1,235 @@ +""" +2023/07/18, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +之前 3 blocks的模型深层特征存在严重的gradient vanishing问题。现在把模型的encoder和decoder都改成只有3个blocks, 看是否还存在gradient vanishing问题。 +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 3, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2] + heads=[1, 2, 4] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_restormer_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b51d3a218057206863230973d557173e891d3cd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mmunet_restormer_v2.py @@ -0,0 +1,220 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +这个版本中是将Unet encoder提取的特征直接连接到decoder, 经过restormer refine的特征只是传递到下一层的encoder block. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + output = feature + + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mtrans_mca.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mtrans_mca.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4c89df4af71e986fa7c24d522e6732371f3128 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mtrans_mca.py @@ -0,0 +1,119 @@ +import torch +from torch import nn + +from .mtrans_transformer import Mlp + +# implement by YunluYan + + +class LinearProject(nn.Module): + def __init__(self, dim_in, dim_out, fn): + super(LinearProject, self).__init__() + need_projection = dim_in != dim_out + self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity() + self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity() + self.fn = fn + + def forward(self, x, *args, **kwargs): + x = self.project_in(x) + x = self.fn(x, *args, **kwargs) + x = self.project_out(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): + super(MultiHeadCrossAttention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim*2, bias=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x, complement): + + # x [B, HW, C] + B_x, N_x, C_x = x.shape + + x_copy = x + + complement = torch.cat([x, complement], 1) + + B_c, N_c, C_c = complement.shape + + # q [B, heads, HW, C//num_heads] + q = self.to_q(x).reshape(B_x, N_x, self.num_heads, C_x//self.num_heads).permute(0, 2, 1, 3) + kv = self.to_kv(complement).reshape(B_c, N_c, 2, self.num_heads, C_c//self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_x, N_x, C_x) + + x = x + x_copy + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio = 1., attn_drop=0., proj_drop=0.,drop_path = 0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super(CrossTransformerEncoderLayer, self).__init__() + self.x_norm1 = norm_layer(dim) + self.c_norm1 = norm_layer(dim) + + self.attn = MultiHeadCrossAttention(dim, num_heads, attn_drop, proj_drop) + + self.x_norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) + + self.drop1 = nn.Dropout(drop_path) + self.drop2 = nn.Dropout(drop_path) + + def forward(self, x, complement): + x = self.x_norm1(x) + complement = self.c_norm1(complement) + + x = x + self.drop1(self.attn(x, complement)) + x = x + self.drop2(self.mlp(self.x_norm2(x))) + return x + + + +class CrossTransformer(nn.Module): + def __init__(self, x_dim, c_dim, depth, num_heads, mlp_ratio =1., attn_drop=0., proj_drop=0., drop_path =0.): + super(CrossTransformer, self).__init__() + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LinearProject(x_dim, c_dim, CrossTransformerEncoderLayer(c_dim, num_heads, mlp_ratio, attn_drop, 
proj_drop, drop_path)), + LinearProject(c_dim, x_dim, CrossTransformerEncoderLayer(x_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)) + ])) + + def forward(self, x, complement): + for x_attn_complemnt, complement_attn_x in self.layers: + x = x_attn_complemnt(x, complement=complement) + x + complement = complement_attn_x(complement, complement=x) + complement + return x, complement + + + + + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mtrans_net.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mtrans_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e80dadb45725603c7ff1da664a45533575db25 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mtrans_net.py @@ -0,0 +1,145 @@ +""" +This script implement the paper "Multi-Modal Transformer for Accelerated MR Imaging (TMI 2022)" +""" + + +import torch + +from torch import nn +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler +from .mtrans_mca import CrossTransformer + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(INPUT_DIM, HEAD_HIDDEN_DIM): + return ReconstructionHead(INPUT_DIM, HEAD_HIDDEN_DIM) + + +class CrossCMMT(nn.Module): + + def __init__(self, args): + super(CrossCMMT, self).__init__() + + INPUT_SIZE = 240 + INPUT_DIM = 1 # the channel of input + OUTPUT_DIM = 1 # the channel of output + HEAD_HIDDEN_DIM = 16 # the hidden dim of Head + TRANSFORMER_DEPTH = 4 # the depth of the transformer + TRANSFORMER_NUM_HEADS = 4 # the head's num of multi head attention + TRANSFORMER_MLP_RATIO = 3 # the MLP RATIO Of transformer + TRANSFORMER_EMBED_DIM = 256 # the EMBED DIM of transformer + P1 = 8 + P2 = 16 + CTDEPTH = 4 + + + self.head = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + self.head2 = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + + x_patch_dim = HEAD_HIDDEN_DIM * P1 ** 2 + x_num_patches = (INPUT_SIZE // P1) ** 2 + + complement_patch_dim = HEAD_HIDDEN_DIM * P2 ** 2 + complement_num_patches = (INPUT_SIZE // P2) ** 2 + + self.x_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P1, + p2=P1), + ) + + self.complement_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P2, + p2=P2), + ) + + self.x_pos_embedding = nn.Parameter(torch.randn(1, x_num_patches, x_patch_dim)) + self.complement_pos_embedding = nn.Parameter(torch.randn(1, complement_num_patches, complement_patch_dim)) + + ### + self.cross_transformer = CrossTransformer(x_patch_dim, complement_patch_dim, CTDEPTH, + 
TRANSFORMER_NUM_HEADS, TRANSFORMER_MLP_RATIO) + + self.p1 = P1 + self.p2 = P2 + + self.tail1 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + self.tail2 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x, complement): + x = self.head(x) + complement = self.head2(complement) + + x = x + complement + + b, _, h, w = x.shape + + _, _, c_h, c_w = complement.shape + + ### 得到每个模态的patch embedding (加上了position encoding) + x = self.x_patch_embbeding(x) + x += self.x_pos_embedding + + complement = self.complement_patch_embbeding(complement) + complement += self.complement_pos_embedding + + ### 将两个模态的patch embeddings送到cross transformer中提取特征。 + x, complement = self.cross_transformer(x, complement) + + c = int(x.shape[2] / (self.p1 * self.p1)) + H = int(h / self.p1) + W = int(w / self.p1) + + x = x.reshape(b, H, W, self.p1, self.p1, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w, ) + x = self.tail1(x) + + complement = complement.reshape(b, int(c_h/self.p2), int(c_w/self.p2), self.p2, self.p2, int(complement.shape[2]/self.p2/self.p2)) + complement = complement.permute(0, 5, 1, 3, 2, 4) + complement = complement.reshape(b, -1, c_h, c_w) + + complement = self.tail2(complement) + + return x, complement + + +def build_model(args): + return CrossCMMT(args) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mtrans_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mtrans_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..61314a6a81247b0740a1e98d078242b26ce73f90 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/mtrans_transformer.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = 
self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
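# --- Editor's illustration (not part of the patch): the "stochastic depth decay rule" in
# VisionTransformer above spreads drop_path_rate linearly over the blocks, so the first block is
# never dropped and the last is dropped with the full rate. The 0.1 rate below is only an example
# value (the constructor default is 0.0).
import torch

depth, drop_path_rate = 4, 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print(dpr)  # [0.0, 0.0333..., 0.0666..., 0.1]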
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(args, embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=args.MODEL.TRANSFORMER_DEPTH, + num_heads=args.MODEL.TRANSFORMER_NUM_HEADS, + mlp_ratio=args.MODEL.TRANSFORMER_MLP_RATIO) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_concat_decomp.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_concat_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..07c6f64804dd586927814801b167163f0b65f58f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_concat_decomp.py @@ -0,0 +1,295 @@ +""" +Xiaohan Xing, 2023/06/25 +两个模态分别用CNN提取多个层级的特征, 每个层级的特征都经过concat之后得到fused feature, +然后每个模态都通过一层conv变换得到2C*h*w的特征, 其中C个channels作为common feature, 另C个channels作为specific features. 
+利用CDDFuse论文中的decomposition loss约束特征解耦。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.transform_layers1 = nn.ModuleList() + self.transform_layers2 = nn.ModuleList() + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.transform_layers1.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.transform_layers2.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + stack1_common, stack1_specific, stack2_common, stack2_specific = [], [], [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_transform_layer, net2_transform_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.transform_layers1, self.transform_layers2, \ + self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + # print("original feature:", output1.shape) + + ### 分别提取两个模态的common feature和specific feature + output1 = net1_transform_layer(output1) + # print("transformed feature:", output1.shape) + output1_common = output1[:, :(output1.shape[1]//2), :, :] + output1_specific = output1[:, (output1.shape[1]//2):, :, :] + output2 = net2_transform_layer(output2) + output2_common = output2[:, :(output2.shape[1]//2), :, :] + output2_specific = output2[:, (output2.shape[1]//2):, :, :] + + stack1_common.append(output1_common) + stack1_specific.append(output1_specific) + stack2_common.append(output2_common) + stack2_specific.append(output2_specific) + + ### 将一个模态的两组feature和另一模态的common feature合并, 然后经过conv变换得到和初始特征相同维度的fused feature送到下一层。 + output1 = net1_fuse_layer(torch.cat((output1_common, output1_specific, output2_common), 1)) + output2 = net2_fuse_layer(torch.cat((output2_common, output2_specific, output1_common), 1)) + # print("fused feature:", output1.shape) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, stack1_common, stack1_specific, stack2_common, stack2_specific + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
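# --- Editor's sketch (not part of the patch): the forward above returns the per-level
# common/specific stacks so that a decomposition loss can be applied outside the model, as the
# file header notes. Below is one correlation-based formulation in the spirit of the CDDFuse
# decomposition loss; the helper names and the 1.01 stabilising offset are illustrative, and the
# exact form should follow the CDDFuse paper.
import torch

def _cc(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Mean correlation coefficient between two (N, C, H, W) feature maps."""
    a = a - a.mean(dim=(2, 3), keepdim=True)
    b = b - b.mean(dim=(2, 3), keepdim=True)
    num = (a * b).sum(dim=(2, 3))
    den = torch.sqrt((a * a).sum(dim=(2, 3)) * (b * b).sum(dim=(2, 3))) + eps
    return (num / den).mean()

def decomposition_loss(common1, specific1, common2, specific2):
    """Keep the shared features correlated while pushing modality-specific ones apart."""
    return (_cc(specific1, specific2) ** 2) / (_cc(common1, common2) + 1.01)

# loss_decomp = sum(decomposition_loss(c1, s1, c2, s2) for c1, s1, c2, s2 in
#                   zip(stack1_common, stack1_specific, stack2_common, stack2_specific))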
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_multi_concat.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_multi_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90373e1134210809b98fdc64e3d51d76123855 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_multi_concat.py @@ -0,0 +1,306 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if 
output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + # output1, output2 = image1, image2 + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
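# A minimal sketch of the concatenation-based fusion used per level in this variant
# (illustrative only): the encoder output of each modality is kept for the skip connection,
# down-sampled, and the two pooled maps are concatenated channel-wise before a conv maps
# 2C channels back to C for the next level. Plain Conv2d layers stand in for the ConvBlock
# fuse layers, and both fusion branches here read the same pre-fusion features.
import torch
import torch.nn as nn
import torch.nn.functional as F

C = 32
fuse1 = nn.Conv2d(2 * C, C, kernel_size=3, padding=1)   # stands in for fuse_conv_layers1[l]
fuse2 = nn.Conv2d(2 * C, C, kernel_size=3, padding=1)   # stands in for fuse_conv_layers2[l]

feat1 = torch.randn(2, C, 120, 120)   # modality-1 encoder feature (kept for the skip connection)
feat2 = torch.randn(2, C, 120, 120)   # modality-2 encoder feature (kept for the skip connection)

pooled1 = F.avg_pool2d(feat1, kernel_size=2, stride=2)
pooled2 = F.avg_pool2d(feat2, kernel_size=2, stride=2)

fused_cat = torch.cat((pooled1, pooled2), dim=1)         # (2, 2C, 60, 60)
next_in1 = fuse1(fused_cat)                              # (2, C, 60, 60) -> next encoder level, branch 1
next_in2 = fuse2(fused_cat)                              # (2, C, 60, 60) -> next encoder level, branch 2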
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # relation1 = get_relation_matrix(output1) + # relation2 = get_relation_matrix(output2) + # relation_stack1.append(relation1) + # relation_stack2.append(relation2) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + # output1 = net1_layer(output1) + # output2 = net2_layer(output2) + + # ### 两个模态的特征concat + # fused_output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + # fused_output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # stack1.append(fused_output1) + # stack2.append(fused_output2) + + # ### downsampling features + # output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + # output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_multi_sum.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_multi_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4efb13b54ccbeaeb34fc51d3e72b0c50ee242 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_multi_sum.py @@ -0,0 +1,270 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] 
!= downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers): + + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + + ## 两个模态的特征求平均. + fused_output1 = (output1 + output2)/2.0 + fused_output2 = (output1 + output2)/2.0 + + output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features #, relation_stack1, relation_stack2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_multi_transfuse.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_multi_transfuse.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ce250a078ba0847c729c1ef1b76452f64d0f07 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_multi_transfuse.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +2023/06/23 +每个层级的特征用TransFuse layer融合之后, 将两个模态的original features和transfuse features concat, +然后在每个模态中经过conv层变换得到下一层的输入。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .TransFuse import TransFuse_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_TransFuse(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + ### conv layer to transform the features after Transfuse layer. + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + self.n_anchors = 15 + self.avgpool = nn.AdaptiveAvgPool2d((self.n_anchors, self.n_anchors)) + self.transfuse_layer1 = TransFuse_layer(n_embd=self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer2 = TransFuse_layer(n_embd=2*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer3 = TransFuse_layer(n_embd=4*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer4 = TransFuse_layer(n_embd=8*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + + self.transfuse_layers = nn.Sequential(self.transfuse_layer1, self.transfuse_layer2, self.transfuse_layer3, self.transfuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, transfuse_layer, net1_fuse_layer, net2_fuse_layer in zip(self.encoder1.down_sample_layers, \ + self.encoder2.down_sample_layers, self.transfuse_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### extract multi-level features from the encoder of each modality. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + ### Transformer-based multi-modal fusion layer + feature1 = self.avgpool(output1) + feature2 = self.avgpool(output2) + feature1, feature2 = transfuse_layer(feature1, feature2) + feature1 = F.interpolate(feature1, scale_factor=output1.shape[-1]//self.n_anchors, mode='bilinear') + feature2 = F.interpolate(feature2, scale_factor=output2.shape[-1]//self.n_anchors, mode='bilinear') + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_TransFuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_swinfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_swinfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6203c4331744f1da70f18e403360196a419b4cb5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/munet_swinfusion.py @@ -0,0 +1,268 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .SwinFuse_layer import SwinFusion_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_SwinFuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过SwinFusion的方式融合。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. 
+ num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.fuse_type = 'swinfuse' + self.img_size = 240 + # self.fuse_type = 'sum' + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # Swin transformer fusion modules + + self.fuse_layer1 = SwinFusion_layer(img_size=(self.img_size//2, self.img_size//2), patch_size=4, window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer2 = SwinFusion_layer(img_size=(self.img_size//4, self.img_size//4), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=2*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer3 = SwinFusion_layer(img_size=(self.img_size//8, self.img_size//8), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=4*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer4 = SwinFusion_layer(img_size=(self.img_size//16, self.img_size//16), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=8*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.swinfusion_layers = nn.Sequential(self.fuse_layer1, self.fuse_layer2, self.fuse_layer3, self.fuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, swinfuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.swinfusion_layers): + output1 = net1_layer(output1) + stack1.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # print("output of encoders:", output1.shape, output2.shape) + # print("CNN feature range:", output1.max(), output2.min()) + + ### Swin transformer-based multi-modal fusion. 
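# A minimal sketch of the per-level SwinFusion step set up above: the pooled features of both
# modalities are passed through a Swin-transformer fusion layer (stubbed out here; the real
# model uses the external SwinFusion_layer) and the result is averaged with the original
# features. The average is computed once here and shared by both branches; in the code above,
# output1 is overwritten before output2 is formed, so its second average can differ slightly
# from this symmetric form.
import torch

def stub_swinfuse(a: torch.Tensor, b: torch.Tensor):
    """Placeholder for SwinFusion_layer: any module returning two refined feature maps."""
    return a, b

feat1 = torch.randn(2, 32, 120, 120)   # modality-1 pooled feature at this level
feat2 = torch.randn(2, 32, 120, 120)   # modality-2 pooled feature at this level

ref1, ref2 = stub_swinfuse(feat1, feat2)
fused = (feat1 + ref1 + feat2 + ref2) / 4.0   # four-way average of original and fused features
next_in1 = fused                               # both branches continue with the fused map
next_in2 = fused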
+ feature1, feature2 = swinfuse_layer(output1, output2) + output1 = (output1 + feature1 + output2 + feature2)/4.0 + output2 = (output1 + feature1 + output2 + feature2)/4.0 + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_SwinFuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/original_MINet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/original_MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a115872320f677d8b34264eed95c22fff0df9c4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/original_MINet.py @@ -0,0 +1,335 @@ +from fastmri.models import common +import torch +import torch.nn as nn +import torch.nn.functional as F + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = 
self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,n_resgroups,n_resblocks,n_feats,conv=common.default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 2 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + common.Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class MINet(nn.Module): + def __init__(self, n_resgroups,n_resblocks, n_feats): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1,x2#x1=pd x2=pdfs + \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3699cc809e438be4e1bd5efaba7d96e18d7f1602 --- /dev/null +++ 
b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/restormer.py @@ -0,0 +1,307 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + # print("feature before FFN projection:", x.max(), x.min()) + x = self.project_out(x) + # print("FFN output:", x.max(), x.min()) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, 
stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + b,c,h,w = x.shape + # print("Attention block input:", x.max(), x.min()) + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + # print("attention values:", attn) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("Attention block output:", out.max(), out.min()) + + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + # dim = 32, + # num_blocks = [4,4,4,4], + dim = 32, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + + # stack = [] + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + # print("encoder block1 feature:", out_enc_level1.shape, out_enc_level1.max(), out_enc_level1.min()) + # stack.append(out_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + # print("encoder block2 feature:", out_enc_level2.shape, out_enc_level2.max(), out_enc_level2.min()) + # stack.append(out_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + # print("encoder block3 feature:", out_enc_level3.shape, out_enc_level3.max(), out_enc_level3.min()) + # 
stack.append(out_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + # print("encoder block4 feature:", latent.shape, latent.max(), latent.min()) + # stack.append(latent) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + + return out_dec_level1 #, stack + + + +def build_model(args): + return Restormer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/restormer_block.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/restormer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..170ae6a826e5a3e356cb48f8c89500c2d5242816 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/restormer_block.py @@ -0,0 +1,321 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / 
torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + # print("input to the LayerNorm layer:", x.max(), x.min()) + h, w = x.shape[-2:] + x = to_4d(self.body(to_3d(x)), h, w) + # print("output of the LayerNorm:", x.max(), x.min()) + return x + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature = 1.0 / ((dim//num_heads) ** 0.5) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + # print("input to the Attention layer:", x.max(), x.min()) + + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + # print("q range:", q.max(), q.min()) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + # print("scaling parameter:", self.temperature) + # print("attention matrix before softmax:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + attn = attn.softmax(dim=-1) + # print("attention matrix range:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + # print("v range:", v.max(), v.min(), v.mean()) + + out = (attn @ v) + # print("self-attention output:", out.max(), out.min()) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("output of attention layer:", out.max(), out.min()) + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, pre_norm): + super(TransformerBlock, self).__init__() + + self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, 
ffn_expansion_factor, bias) + self.pre_norm = pre_norm + + def forward(self, x): + x = self.conv(x) + # print("restormer block input:", x.max(), x.min()) + if self.pre_norm: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + else: + x = self.norm1(x + self.attn(x)) + x = self.norm2(x + self.ffn(x)) + # x = self.norm1(self.attn(x)) + # x = self.norm2(self.ffn(x)) + # x = x + self.attn(x) + # x = x + self.ffn(x) + # print("restormer block output:", x.max(), x.min()) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim = 48, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + 
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + + + return out_dec_level1 + + + + +if __name__ == '__main__': + model = Restormer() + # print(model) + + A = torch.randn((1, 1, 240, 240)) + x = model(A) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/swinIR.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/swinIR.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfc9c175c02531ac80a05ba4abfc0776f752dd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/swinIR.py @@ -0,0 +1,874 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * 
(self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must be in the range [0, window_size)" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
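+
+        Example (illustrative sketch; the sizes below are assumed for demonstration and are not taken from any config in this repo):
+            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
+            ...                    num_heads=6, window_size=7)
+            >>> x = torch.randn(1, 56 * 56, 96)   # tokens in (B, H*W, C) layout
+            >>> layer(x, (56, 56)).shape          # no downsample, so the token shape is kept
+            torch.Size([1, 3136, 96])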
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 4, 4, 6], + embed_dim=32, num_heads=[6, 6, 6, 6], mlp_ratio=4, upsampler='pixelshuffledirect') + + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/swin_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..877bdfbffc91012e4ac31f41b838ebb116f4a825 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/swin_transformer.py @@ -0,0 +1,878 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
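+# Note: the classes below largely mirror those in swinIR.py in this directory;
+# this file is kept as a separate entry under compare_models.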
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
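# Editorial sketch (not part of the patch): BasicLayer above alternates regular (W-MSA)
# and shifted (SW-MSA) window attention; even-indexed blocks get shift_size=0 and
# odd-indexed blocks get window_size // 2, so windows overlap across successive blocks.
_window_size = 8
_shift_sizes = [0 if i % 2 == 0 else _window_size // 2 for i in range(6)]
assert _shift_sizes == [0, 4, 0, 4, 0, 4]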
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
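# Editorial sketch (not part of the patch): Upsample above builds 2^n scaling from
# repeated (3x3 conv to 4*num_feat) + PixelShuffle(2) stages, while scale=3 uses a
# single conv to 9*num_feat + PixelShuffle(3). Assuming the class above is importable:
_upsample = Upsample(scale=4, num_feat=64)           # two conv + PixelShuffle(2) stages
_hr = _upsample(torch.randn(1, 64, 48, 48))
assert _hr.shape == (1, 64, 192, 192)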
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=1, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + # print("shallow feature:", x.shape) + x = self.conv_after_body(self.forward_features(x)) + x + # print("deep feature:", x.shape) + x = self.upsample(x) + # print("after upsample:", x.shape) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x 
= self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + # return SwinIR(upscale=1, img_size=(240, 240), + # window_size=8, img_range=1., depths=[6, 6, 6, 6], + # embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + +if __name__ == '__main__': + height = 240 + width = 240 + window_size = 8 + model = SwinIR(upscale=1, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + x = torch.randn((1, 1, height, width)) + print("input:", x.shape) + x = model(x) + print("output:", x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet.zip b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet.zip new file mode 100644 index 0000000000000000000000000000000000000000..e7b71c1287551fd6119efd7c62da2196307998f5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:369de2a28036532c446409ee1f7b953d5763641ffd452675613e1bb9bba89eae +size 19354 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet/trans_unet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet/trans_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa61dbadfcaee9b26594307adc90cdd1d2a84e3e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet/trans_unet.py @@ -0,0 +1,489 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math + +from os.path import join as pjoin + +import torch +import torch.nn as nn +import numpy as np + +from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair +from scipy import ndimage +from . 
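# Editorial sketch (not part of the patch): SwinIR.check_image_size above reflect-pads
# H and W up to the next multiple of window_size, and forward() crops the result back
# to (H * upscale, W * upscale), so inputs need not be window-aligned. With window_size=8:
_h, _w, _ws = 230, 241, 8
_pad_h = (_ws - _h % _ws) % _ws                      # 2
_pad_w = (_ws - _w % _ws) % _ws                      # 7
assert (_h + _pad_h) % _ws == 0 and (_w + _pad_w) % _ws == 0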
import vit_seg_configs as configs +# import .vit_seg_configs as configs +from .vit_seg_modeling_resnet_skip import ResNetV2 + + +logger = logging.getLogger(__name__) + + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class 
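# Editorial sketch (not part of the patch): Attention.transpose_for_scores above splits
# hidden_size into num_heads independent heads, e.g. 768 -> 12 x 64 for ViT-B, so the
# dot-product attention runs per head. Shapes below are hypothetical.
_tokens = torch.randn(2, 225, 768)                                  # (B, N, hidden)
_per_head = _tokens.view(2, 225, 12, 64).permute(0, 2, 1, 3)        # (B, heads, N, head_dim)
assert _per_head.shape == (2, 12, 225, 64)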
Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + # print("image size:", img_size, "grid size:", grid_size) + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + # print("patch_size_real", patch_size_real) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + # print("input x:", x.shape) ### (4, 3, 240, 240) + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + # print("output of the hybrid model:", x.shape) ### (4, 1024, 15, 15) + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = 
np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + # print(self.encoder) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) ### "features" store the features from different layers. + # print("embedding_output:", embedding_output.shape) ## (B, n_patch, hidden) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + # print("encoded feature:", encoded.shape) + return encoded, attn_weights, features + + # return embedding_output, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class ReconstructionHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + 
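# Editorial sketch (not part of the patch): Block.load_from above reads TF/JAX
# checkpoints whose conv kernels are stored HWIO; np2th(..., conv=True) transposes
# them to PyTorch's OIHW layout before copy_().
_hwio = np.zeros((3, 3, 64, 256), dtype=np.float32)   # (H, W, in_ch, out_ch)
_oihw = np2th(_hwio, conv=True)
assert tuple(_oihw.shape) == (256, 64, 3, 3)          # (out_ch, in_ch, H, W)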
self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + # print(self.blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): + super(VisionTransformer, self).__init__() + self.num_classes = num_classes + self.zero_head = zero_head + self.classifier = config.classifier + self.transformer = Transformer(config, img_size, vis) + self.decoder = DecoderCup(config) + self.segmentation_head = ReconstructionHead( + in_channels=config['decoder_channels'][-1], + out_channels=config['n_classes'], + kernel_size=3, + ) + self.activation = nn.Tanh() + self.config = config + + def forward(self, x): + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + ### hybrid CNN and transformer to extract features from the input images. + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + + # ### extract features with CNN only. 
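# Editorial sketch (not part of the patch): DecoderCup above folds the (B, n_patch,
# hidden) sequence back into a square h x w map with h = w = sqrt(n_patch) and then
# upsamples x2 in each of the four decoder blocks; with the (15, 15) grid used for
# 240 x 240 inputs this recovers the full resolution, since 15 * 2**4 == 240.
assert int(np.sqrt(15 * 15)) * 2 ** 4 == 240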
+ # x, features = self.transformer(x) # (B, n_patch, hidden) + + x = self.decoder(x, features) + # logits = self.segmentation_head(x) + logits = self.activation(self.segmentation_head(x)) + # print("logits:", logits.shape) + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + + + +def build_model(args): + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + print(net) + # net.load_from(weights=np.load(config_vit.pretrained_path)) + return net + + +if __name__ == "__main__": + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + net.load_from(weights=np.load(config_vit.pretrained_path)) + + input_data = torch.rand(4, 1, 240, 240).cuda() + print("input:", input_data.shape) + output = net(input_data) + print("output:", output.shape) diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bed7d3346e30328a38ac6fef1621f788d905df1e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_configs.py @@ -0,0 +1,132 @@ +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 4 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + # config.patches.grid = (16, 16) + config.patches.grid = (15, 15) ### for the input image with size (240, 240) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'recon' + config.pretrained_path = '/home/xiaohan/workspace/MSL_MRI/code/pretrained_model/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 1 + config.n_skip = 3 + config.activation = 'tanh' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 
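# Editorial sketch (not part of the patch): get_r50_b16_config() above adapts the hybrid
# R50+ViT-B/16 model to 240 x 240 reconstruction: a (15, 15) feature grid (225 tokens of
# width 768), three skip connections, and a single-channel 'tanh' output head.
_cfg = get_r50_b16_config()
print(_cfg.patches.grid, _cfg.hidden_size, _cfg.n_skip, _cfg.n_classes)   # (15, 15) 768 3 1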
configuration. customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae80ef7d96c46cb91bda849901aef34008e62ed --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py @@ -0,0 +1,160 @@ +import math + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. 
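# Editorial sketch (not part of the patch): StdConv2d above applies Weight
# Standardization, normalizing every output filter to zero mean / unit variance over
# its (in_ch, kH, kW) extent before convolving, which pairs with the GroupNorm layers
# used throughout this ResNetV2.
_std_conv = StdConv2d(3, 8, kernel_size=3, padding=1)
assert _std_conv(torch.randn(1, 3, 16, 16)).shape == (1, 8, 16, 16)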
+ self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - 
x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/transformer_modules.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/transformer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..900042884305619e39ad9b792b3b1936dfbd37b3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/transformer_modules.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ 
k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=4, + num_heads=4, + mlp_ratio=3) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dd935cd34a121d8079a8f5184d4ea29e15c8da --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unet.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F + + +class Unet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. 
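# Editorial sketch (not part of the patch): drop_path above implements stochastic depth
# at the sample level: with probability drop_prob a whole sample's residual branch is
# zeroed, and survivors are scaled by 1 / keep_prob so the expected value is unchanged.
_residual = torch.ones(8, 16, 4)
_dropped = drop_path(_residual, drop_prob=0.5, training=True)
# each of the 8 samples is either all zeros or all 2.0; E[_dropped] == _residual
assert set(_dropped.unique().tolist()) <= {0.0, 2.0}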
In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = image + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + # print("unet feature range:", output.max(), output.min()) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
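+
+        Example (illustrative only; the tensor sizes are arbitrary):
+            >>> block = ConvBlock(in_chans=1, out_chans=32, drop_prob=0.0)
+            >>> block(torch.randn(2, 1, 240, 240)).shape
+            torch.Size([2, 32, 240, 240])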
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unet_restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f78aa5f192957b2706c6f691dfc66c427b65ae --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unet_restormer.py @@ -0,0 +1,234 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + # print("Unet feature range:", output.max(), output.min()) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unet_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..982fe23cec321de6b636e04c51b62037054ea432 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unet_transformer.py @@ -0,0 +1,346 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +对于Unet中的high-level feature, 用Transformer进行特征交互,然后和Transformer之前的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = 
dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Unet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
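+
+        Note:
+            The fusion path assumes the encoder output is a 15x15 feature map
+            (e.g. a 240x240 input after num_pool_layers=4 poolings). It is
+            flattened into 225 patch tokens, prefixed with a learnable cls
+            token, passed through the Transformer, and the result is
+            concatenated channel-wise with the CNN feature before decoding.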
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output = image + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack, output = self.encoder(output) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output) # [4, 225, 256], 225 is the number of patches. + patch_embed_input = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1])), int(np.sqrt(feature_output.shape[1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[-1], h, w) # [4,512*2,24,24] + + output = torch.cat((output, feature_output), 1) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output = self.decoder(output, stack) + + return output + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. 
+ + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Transformer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unimodal_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unimodal_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c99214e10bee2f7a1be49f8560799ffbc14ed0a5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/compare_models/unimodal_transformer.py @@ -0,0 +1,112 @@ +import torch + +from torch import nn +from .transformer_modules import build_transformer +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(): + return ReconstructionHead(input_dim=1, hidden_dim=12) + + + +class CMMT(nn.Module): + + def __init__(self, args): + super(CMMT, self).__init__() + + self.head = build_head() + + HEAD_HIDDEN_DIM = 12 + PATCH_SIZE = 16 + INPUT_SIZE = 240 + OUTPUT_DIM = 1 + + patch_dim = HEAD_HIDDEN_DIM* PATCH_SIZE ** 2 + num_patches = (INPUT_SIZE // PATCH_SIZE) ** 2 + + self.transformer = build_transformer(patch_dim) + + self.patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=PATCH_SIZE, p2=PATCH_SIZE), + ) + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, patch_dim)) + + + self.p1 = PATCH_SIZE + self.p2 = PATCH_SIZE + + self.tail = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x): + + b,_ , h, w = x.shape + + x = self.head(x) + + x= self.patch_embbeding(x) + + x += self.pos_embedding + + x = self.transformer(x) # b HW p1p2c + + c = 
int(x.shape[2]/(self.p1*self.p2)) + H = int(h/self.p1) + W = int(w/self.p2) + + x = x.reshape(b, H, W, self.p1, self.p2, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w,) + x = self.tail(x) + + return x + + +def build_model(args): + return CMMT(args) + + + + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/modules.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/modules.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
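+# For an input x of shape (N, C, H, W) the forward pass below computes, per
+# sample and per channel,
+#     y = scale * (x - mean_HW(x)) / sqrt(var_HW(x) + eps) + shift
+# where `scale` and `shift` are learned per-channel parameters applied only
+# when affine=True.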
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/mynet.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/mynet.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e93c185c773070d07437777b9c01ff11824d4b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/networks/mynet.py @@ -0,0 +1,388 @@ +import torch +from torch import nn +from networks import common_freq as common + + +class TwoBranch(nn.Module): + def __init__(self, args): + super(TwoBranch, self).__init__() + + num_group = 4 + num_every_group = args.base_num_every_group + self.args = args + + self.init_T2_frq_branch(args) + self.init_T2_spa_branch(args, num_every_group) + self.init_T2_fre_spa_fusion(args) + + self.init_T1_frq_branch(args) + self.init_T1_spa_branch(args, num_every_group) + + self.init_modality_fre_fusion(args) + self.init_modality_spa_fusion(args) + + + def init_T2_frq_branch(self, args): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_fre = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo = 
nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_neck_fre = [common.FreBlock9(args.num_features, args) + ] + self.neck_fre = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up1_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up1_fre = nn.Sequential(*modules_up1_fre) + self.up1_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up2_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up2_fre = nn.Sequential(*modules_up2_fre) + self.up2_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_up3_fre = [common.UpSampler(2, args.num_features), + common.FreBlock9(args.num_features, args) + ] + self.up3_fre = nn.Sequential(*modules_up3_fre) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(args.num_features, args)) + + # define tail module + modules_tail_fre = [ + common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act)] + self.tail_fre = nn.Sequential(*modules_tail_fre) + + def init_T2_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1 = nn.Sequential(*modules_down1) + + + self.down1_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2 = nn.Sequential(*modules_down2) + + self.down2_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down3 = nn.Sequential(*modules_down3) + self.down3_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.neck = nn.Sequential(*modules_neck) + + self.neck_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_up1 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up1 = nn.Sequential(*modules_up1) + + 
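+        # Each decoder stage is split in two: the "up*" module (UpSampler +
+        # ResidualGroup) produces the upsampled feature, and the matching
+        # "up*_mo" ResidualGroup refines it in forward() after the
+        # frequency/spatial FuseBlocks have been applied.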
self.up1_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_up2 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up2 = nn.Sequential(*modules_up2) + self.up2_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + + modules_up3 = [common.UpSampler(2, args.num_features), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.up3 = nn.Sequential(*modules_up3) + self.up3_mo = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + # define tail module + modules_tail = [ + common.ConvBNReLU2D(args.num_features, out_channels=args.num_channels, kernel_size=3, padding=1, + act=args.act)] + + self.tail = nn.Sequential(*modules_tail) + + def init_T2_fre_spa_fusion(self, args): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(args.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, args): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_fre_T1 = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + + self.down1_fre_T1 = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down2_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down2_fre_T1 = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_down3_fre = [common.DownSample(args.num_features, False, False), + common.FreBlock9(args.num_features, args) + ] + self.down3_fre_T1 = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + modules_neck_fre = [common.FreBlock9(args.num_features, args) + ] + self.neck_fre_T1 = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo_T1 = nn.Sequential(common.FreBlock9(args.num_features, args)) + + def init_T1_spa_branch(self, args, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=args.num_features, + kernel_size=3, padding=1, act=args.act)] + self.head_T1 = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = nn.Sequential(*modules_down1) + + + self.down1_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(args.num_features, False, False), + common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = nn.Sequential(*modules_down2) + + self.down2_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(args.num_features, False, False), + 
common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.down3_T1 = nn.Sequential(*modules_down3) + self.down3_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None) + ] + self.neck_T1 = nn.Sequential(*modules_neck) + + self.neck_mo_T1 = nn.Sequential(common.ResidualGroup( + args.num_features, 3, 4, act=args.act, n_resblocks=num_every_group, norm=None)) + + + def init_modality_fre_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, args): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(args.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + def forward(self, main, aux): + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo) # 64 + up2_fre_mo = self.up2_fre_mo(up2_fre) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse) + down1_fuse_mo = 
self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.up3(up2_fuse_mo) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + # import matplotlib.pyplot as plt + # plt.axis('off') + # plt.imshow((255*up3_fre_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fre_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + + # plt.axis('off') + # plt.imshow((255*up3_fuse_mo[0].detach().cpu().numpy()[0])) + # plt.savefig('up3_fuse_mo.jpg', bbox_inches='tight', pad_inches=0) + # plt.clf() + # breakpoint() + + res = self.tail(up3_fuse_mo) + + return {'img_out': res + main, 'img_fre': res_fre + main} + +def make_model(args): + return TwoBranch(args) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/option.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/option.py new file mode 100644 index 0000000000000000000000000000000000000000..f6822c0797cdf1191be2a2c6c16842d65d3b8138 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/option.py @@ -0,0 +1,62 @@ +import argparse + +parser = argparse.ArgumentParser(description='MRI recon') +parser.add_argument('--root_path', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/selected_images/') +parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') +parser.add_argument('--low_field_SNR', type=int, default=15, help='SNR of the simulated low-field image') +parser.add_argument('--phase', type=str, default='train', help='Name of phase') +parser.add_argument('--gpu', type=str, default='0', help='GPU to use') +parser.add_argument('--exp', type=str, default='msl_model', help='model_name') +parser.add_argument('--max_iterations', type=int, default=100000, help='maximum epoch number to train') +parser.add_argument('--batch_size', type=int, default=8, help='batch_size per gpu') +parser.add_argument('--base_lr', type=float, default=0.0002, 
help='maximum epoch numaber to train') +parser.add_argument('--seed', type=int, default=1337, help='random seed') +parser.add_argument('--resume', type=str, default=None, help='resume') +parser.add_argument('--relation_consistency', type=str, default='False', help='regularize the consistency of feature relation') +parser.add_argument('--clip_grad', type=str, default='True', help='clip gradient of the network parameters') + + +parser.add_argument('--norm', type=str, default='False', help='Norm Layer between UNet and Transformer') +parser.add_argument('--input_normalize', type=str, default='mean_std', help='choose from [min_max, mean_std, divide]') + +parser.add_argument("--dist_url", default="63654") + +parser.add_argument('--scale', type=int, default=8, + help='super resolution scale') +parser.add_argument('--base_num_every_group', type=int, default=2, + help='super resolution scale') + + +parser.add_argument('--rgb_range', type=int, default=255, + help='maximum value of RGB') +parser.add_argument('--n_colors', type=int, default=3, + help='number of color channels to use') +parser.add_argument('--augment', action='store_true', + help='use data augmentation') +parser.add_argument('--fftloss', action='store_true', + help='use data augmentation') +parser.add_argument('--fftd', action='store_true', + help='use data augmentation') +parser.add_argument('--fftd_weight', type=float, default=0.1, + help='use data augmentation') +parser.add_argument('--fft_weight', type=float, default=0.01) + +# Model specifications +parser.add_argument('--model', type=str, default='MYNET') +parser.add_argument('--act', type=str, default='PReLU') +parser.add_argument('--data_range', type=float, default=1) +parser.add_argument('--num_channels', type=int, default=1) +parser.add_argument('--num_features', type=int, default=64) + +parser.add_argument('--n_feats', type=int, default=64, + help='number of feature maps') +parser.add_argument('--res_scale', type=float, default=0.2, + help='residual scaling') + +parser.add_argument('--MASKTYPE', type=str, default='random') # "random" or "equispaced" +parser.add_argument('--CENTER_FRACTIONS', nargs='+', type=float) +parser.add_argument('--ACCELERATIONS', nargs='+', type=int) + + + +args = parser.parse_args() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/test_brats.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/test_brats.py new file mode 100644 index 0000000000000000000000000000000000000000..f371d55781cb361124387c7c651d5b133a2f5600 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/test_brats.py @@ -0,0 +1,150 @@ +import os +import sys +from tqdm import tqdm +import shutil +import argparse +import logging +import numpy as np +from skimage import io +from scipy.ndimage import zoom + +import torch +import torch.nn as nn +from torchvision import transforms +from torch.utils.data import DataLoader +from networks.compare_models import build_model_from_name +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import ToTensor +from networks.mynet import TwoBranch +from utils import bright, trunc +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + + +parser = argparse.ArgumentParser() +parser.add_argument('--root_path', type=str, default='/home/xiaohan/datasets/BRATS_dataset/BRATS_2020_images/selected_images/') +parser.add_argument('--MRIDOWN', type=str, default='4X', help='MRI down-sampling rate') 
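+# Example invocation (illustrative; the data path is an assumption and the flags
+# are ultimately parsed from option.py rather than this local parser):
+#   python test_brats.py --root_path /path/to/BRATS_2020_images/selected_images/ \
+#       --phase test --MRIDOWN 4X --low_field_SNR 15 --exp msl_model --gpu 0 \
+#       --input_normalize mean_std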
+parser.add_argument('--low_field_SNR', type=int, default=15, help='SNR of the simulated low-field image') +parser.add_argument('--phase', type=str, default='test', help='Name of phase') +parser.add_argument('--gpu', type=str, default='0', help='GPU to use') +parser.add_argument('--exp', type=str, default='msl_model', help='model_name') +parser.add_argument('--seed', type=int, default=1337, help='random seed') +parser.add_argument('--base_lr', type=float, default=0.0002, help='maximum epoch numaber to train') + +parser.add_argument('--model_name', type=str, default='unet_single', help='model_name') +parser.add_argument('--relation_consistency', type=str, default='False', help='regularize the consistency of feature relation') +parser.add_argument('--norm', type=str, default='False', help='Norm Layer between UNet and Transformer') +parser.add_argument('--input_normalize', type=str, default='mean_std', help='choose from [min_max, mean_std, divide]') + +# args = parser.parse_args() +from option import args +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + +def normalize_output(out_img): + out_img = (out_img - out_img.min())/(out_img.max() - out_img.min() + 1e-8) + return out_img + + +if __name__ == "__main__": + ## make logger file + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + network = TwoBranch(args).cuda() + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + db_test = MyDataset(split='test', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([ToTensor()]), + base_dir=test_data_path, input_normalize = args.input_normalize) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=2, pin_memory=True) + + if args.phase == 'test': + + save_mode_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + print('load weights from ' + save_mode_path) + checkpoint = torch.load(save_mode_path) + network.load_state_dict(checkpoint['network']) + network.eval() + cnt = 0 + save_path = snapshot_path + '/result_case/' + feature_save_path = snapshot_path + '/feature_visualization/' + if not os.path.exists(save_path): + os.makedirs(save_path) + if not os.path.exists(feature_save_path): + os.makedirs(feature_save_path) + + + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + for (sampled_batch, sample_stats) in tqdm(testloader, ncols=70): + cnt += 1 + + print('processing ' + str(cnt) + ' image') + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + + t1_out, t2_out = None, None + + + t2_out = network(t2_in, t1_in)['img_out'] + t2_out_2 = network(t2_in, t1_in)['img_out'] + + t1_mean = sample_stats['t1_mean'].data.cpu().numpy()[0] + t1_std = sample_stats['t1_std'].data.cpu().numpy()[0] + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = 
sample_stats['t2_std'].data.cpu().numpy()[0] + + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t2_in_img = (np.clip(t2_in.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_2_img = (np.clip(t2_out_2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + + io.imsave(save_path + str(cnt) + '_t1.png', bright(t1_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2.png', bright(t2_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_in.png', bright(t2_in_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_out.png', bright(t2_out_img,0,0.8)) + io.imsave(save_path + str(cnt) + '_t2_out2.png', bright(t2_out_2_img,0,0.8)) + + + if t2_out is not None: + t2_out_img[t2_out_img < 0.0] = 0.0 + t2_img[t2_img < 0.0] = 0.0 + MSE = mean_squared_error(t2_img, t2_out_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_out_img) + SSIM = structural_similarity(t2_img, t2_out_img) + t2_MSE_all.append(MSE) + t2_PSNR_all.append(PSNR) + t2_SSIM_all.append(SSIM) + print("[t2 MRI] MSE:", MSE, "PSNR:", PSNR, "SSIM:", SSIM) + + # if cnt > 20: + # break + + print("[T2 MRI:] average MSE:", np.array(t2_MSE_all).mean(), "average PSNR:", np.array(t2_PSNR_all).mean(), "average SSIM:", np.array(t2_SSIM_all).mean()) + print("[T2 MRI:] average MSE:", np.array(t2_MSE_all).std(), "average PSNR:", np.array(t2_PSNR_all).std(), "average SSIM:", np.array(t2_SSIM_all).std()) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/test_fastmri.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/test_fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..10153dd8b437436019f5217abeaf13da54fc4e37 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/test_fastmri.py @@ -0,0 +1,168 @@ +import os +import sys +from tqdm import tqdm +import shutil +import argparse +import logging +import numpy as np +from skimage import io +from scipy.ndimage import zoom + +import torch +import torch.nn as nn +from torchvision import transforms +from torch.utils.data import DataLoader +from networks.compare_models import build_model_from_name +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import ToTensor +from networks.mynet import TwoBranch +from utils import bright, trunc +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + +from option import args + +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + +def normalize_output(out_img): + out_img = (out_img - out_img.min())/(out_img.max() - out_img.min() + 1e-8) + return out_img + +from metric import nmse, psnr, ssim, AverageMeter +from collections import defaultdict +@torch.no_grad() +def 
evaluate(model, data_loader, device, save_path): + os.makedirs(save_path, exist_ok=True) + model.eval() + nmse_meter = AverageMeter() + psnr_meter = AverageMeter() + ssim_meter = AverageMeter() + output_dic = defaultdict(dict) + target_dic = defaultdict(dict) + input_dic = defaultdict(dict) + flag=0 + last_name='no' + for data in data_loader: + pd, pdfs, _ = data + name = os.path.basename(pdfs[4][0]).split('.')[0] + if not last_name == name: + last_name = name + flag+=1 + if flag < 3: + continue + elif flag >= 4: + break + else: + pass + + target = pdfs[1] + + mean = pdfs[2] + std = pdfs[3] + + fname = pdfs[4] + slice_num = pdfs[5] + + mean = mean.unsqueeze(1).unsqueeze(2) + std = std.unsqueeze(1).unsqueeze(2) + + mean = mean.to(device) + std = std.to(device) + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + + pd_img = pd_img.to(device) + pdfs_img = pdfs_img.to(device) + target = target.to(device) + + outputs = network(pdfs_img, pd_img)['img_out'] + outputs = outputs.squeeze(1) + + outputs_save = outputs[0].cpu().numpy()/6.0 + outputs_save = np.clip(outputs_save, a_min=-1, a_max=1) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '.png', target[0].cpu().numpy()/6.0) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_in.png', pdfs_img[0][0].cpu().numpy()/6.0) + io.imsave(save_path + str(name) + '_' + str(slice_num[0].cpu().numpy()) + '_out.png', outputs_save) + + outputs = outputs * std + mean + target = target * std + mean + inputs = pdfs_img.squeeze(1) * std + mean + + output_dic[fname[0]][slice_num[0]] = outputs[0] + target_dic[fname[0]][slice_num[0]] = target[0] + input_dic[fname[0]][slice_num[0]] = inputs[0] + our_nmse = nmse(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_psnr = psnr(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + our_ssim = ssim(target[0].cpu().numpy(), outputs[0].cpu().numpy()) + + print('name:{}, slice:{}, psnr:{}, ssim:{}'.format(name, slice_num[0], our_psnr, our_ssim)) + + + for name in output_dic.keys(): + f_output = torch.stack([v for _, v in output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + nmse_meter.update(our_nmse, 1) + psnr_meter.update(our_psnr, 1) + ssim_meter.update(our_ssim, 1) + + print("==> Evaluate Metric") + print("Results ----------") + print("NMSE: {:.4}".format(np.array(nmse_meter.score).mean())) + print("PSNR: {:.4}".format(np.array(psnr_meter.score).mean())) + print("SSIM: {:.4}".format(np.array(ssim_meter.score).mean())) + print("NMSE: {:.4}".format(np.array(nmse_meter.score).std())) + print("PSNR: {:.4}".format(np.array(psnr_meter.score).std())) + print("SSIM: {:.4}".format(np.array(ssim_meter.score).std())) + print("------------------") + model.train() + return {'NMSE': nmse_meter.avg, 'PSNR': psnr_meter.avg, 'SSIM':ssim_meter.avg} + +from dataloaders.fastmri import build_dataset +if __name__ == "__main__": + ## make logger file + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + network = TwoBranch(args).cuda() + device = torch.device('cuda') + 
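The evaluate() routine above (and the training scripts later in this patch) imports nmse, psnr, ssim and AverageMeter from a local metric module that is not included in this diff. The sketch below is an assumption of what that module would need to provide to satisfy the calls made here (slice- and volume-level metrics plus the .update/.score/.avg bookkeeping); it is modelled on the standard fastMRI metric definitions, not taken from the repository's actual metric.py.

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def nmse(gt, pred):
    # Normalised mean squared error.
    return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2

def psnr(gt, pred):
    return peak_signal_noise_ratio(gt, pred, data_range=gt.max())

def ssim(gt, pred):
    # Accepts a single (H, W) slice or a (slices, H, W) volume; averages SSIM over slices.
    if gt.ndim == 2:
        gt, pred = gt[None], pred[None]
    return np.mean([structural_similarity(g, p, data_range=gt.max()) for g, p in zip(gt, pred)])

class AverageMeter:
    # Keeps every individual score (for mean/std reporting) plus a running average.
    def __init__(self):
        self.score, self.sum, self.count, self.avg = [], 0.0, 0, 0.0

    def update(self, val, n=1):
        self.score.append(val)
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count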
network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + + + db_test = build_dataset(args, mode='val') + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'test': + + save_mode_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + print('load weights from ' + save_mode_path) + checkpoint = torch.load(save_mode_path) + + + weights_dict = {} + for k, v in checkpoint['network'].items(): + new_k = k.replace('module.', '') if 'module' in k else k + weights_dict[new_k] = v + # breakpoint() + network.load_state_dict(weights_dict) + network.eval() + + eval_result = evaluate(network, testloader, device, save_path = snapshot_path + '/result_case/') + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/train_brats.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/train_brats.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e3a6c1d7a3d73bb3e5a8347b51f6ea41084c61 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/train_brats.py @@ -0,0 +1,328 @@ +import os +import sys +from tqdm import tqdm +from tensorboardX import SummaryWriter +import shutil +import argparse +import logging +import time +import torch +import numpy as np +import torch.optim as optim +from torchvision import transforms +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision.utils import make_grid +from networks.compare_models import build_model_from_name +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import RandomPadCrop, ToTensor, AddNoise +from networks.mynet import TwoBranch +from option import args +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity + + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr + + + +def cc(img1, img2): + eps = torch.finfo(torch.float32).eps + """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" + N, C, _, _ = img1.shape + img1 = img1.reshape(N, C, -1) + img2 = img2.reshape(N, C, -1) + img1 = img1 - img1.mean(dim=-1, keepdim=True) + img2 = img2 - img2.mean(dim=-1, keepdim=True) + cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum( + img1 **2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1))) + cc = torch.clamp(cc, -1., 1.) 
+ return cc.mean() + + + +def gradient_calllback(network): + for name, param in network.named_parameters(): + if param.grad is not None: + # print("Gradient of {}: {}".format(name, param.grad.abs().mean())) + + if param.grad.abs().mean() == 0: + print("Gradient of {} is 0".format(name)) + + else: + print("Gradient of {} is None".format(name)) + +class AMPLoss(nn.Module): + def __init__(self): + super(AMPLoss, self).__init__() + self.cri = nn.L1Loss() + + def forward(self, x, y): + x = torch.fft.rfft2(x, norm='backward') + x_mag = torch.abs(x) + y = torch.fft.rfft2(y, norm='backward') + y_mag = torch.abs(y) + + return self.cri(x_mag,y_mag) + + +class PhaLoss(nn.Module): + def __init__(self): + super(PhaLoss, self).__init__() + self.cri = nn.L1Loss() + + def forward(self, x, y): + x = torch.fft.rfft2(x, norm='backward') + x_mag = torch.angle(x) + y = torch.fft.rfft2(y, norm='backward') + y_mag = torch.angle(y) + + return self.cri(x_mag, y_mag) + +if __name__ == "__main__": + ## make logger file + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + network = TwoBranch(args).cuda() + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + # network = nn.SyncBatchNorm.convert_sync_batchnorm(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_train = MyDataset(split='train', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + # transform=transforms.Compose([RandomPadCrop(), ToTensor(), AddNoise()]), + transform=transforms.Compose([RandomPadCrop(), ToTensor()]), + base_dir=train_data_path, input_normalize = args.input_normalize) + + db_test = MyDataset(split='test', MRIDOWN=args.MRIDOWN, SNR=args.low_field_SNR, + transform=transforms.Compose([ToTensor()]), + base_dir=test_data_path, input_normalize = args.input_normalize) + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + fixtrainloader = DataLoader(db_train, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'train': + network.train() + + params = list(network.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + max_epoch = max_iterations // len(trainloader) + 1 + + + best_status = {'T1_NMSE': 10000000, 'T1_PSNR': 0, 'T1_SSIM': 0, + 'T2_NMSE': 10000000, 'T2_PSNR': 0, 'T2_SSIM': 0} + fft_weight=0.01 + criterion = nn.L1Loss().to(device, non_blocking=True) + amploss = AMPLoss().to(device, non_blocking=True) + phaloss = PhaLoss().to(device, non_blocking=True) + start_time = time.time() + + for epoch_num in tqdm(range(max_epoch), ncols=70): + time1 = time.time() + debug_time = False + + # Data Preparation Time: 0.01880049705505371 + # Network Forward Time: 0.08233189582824707 + # Loss Calculation Time: 0.08654212951660156 + # Optimizer Step Time: 0.4485752582550049 + + for i_batch, (sampled_batch, sample_stats) in enumerate(trainloader): 
+ time2 = time.time() + + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + + + time3 = time.time() + + if debug_time: + print("Data Preparation Time: ", time3 - time2) + print("t1, t2=", t1.shape, t2.shape) + + outputs = network(t2_in, t1_in) + if debug_time: + print("Network Forward Time: ", time.time() - time2) + + loss = criterion(outputs['img_out'], t2) + \ + fft_weight * amploss(outputs['img_fre'], t2) + fft_weight * phaloss( + outputs['img_fre'], + t2) + \ + criterion(outputs['img_fre'], t2) + if debug_time: + print("Loss Calculation Time: ", time.time() - time2) + + time4 = time.time() + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + ### clip the gradients to a small range. + torch.nn.utils.clip_grad_norm_(network.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + if debug_time: + print("Optimizer Step Time: ", time.time() - time2) + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + # writer.add_scalar('lr', scheduler1.get_lr(), iter_num) + # writer.add_scalar('loss/loss', loss, iter_num) + + if iter_num % 100 == 0: + logging.info('iteration %d [%.2f sec]: learning rate : %f loss : %f ' % (iter_num, time.time()-start_time, scheduler1.get_lr()[0], loss.item())) + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': network.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + + ## ================ Evaluate ================ + logging.info(f'Epoch {epoch_num} Evaluation:') + # print() + t1_MSE_all, t1_PSNR_all, t1_SSIM_all = [], [], [] + t2_MSE_all, t2_PSNR_all, t2_SSIM_all = [], [], [] + + t1_MSE_krecon, t1_PSNR_krecon, t1_SSIM_krecon = [], [], [] + t2_MSE_krecon, t2_PSNR_krecon, t2_SSIM_krecon = [], [], [] + + for (sampled_batch, sample_stats) in testloader: + + t1_in, t1, t2_in, t2 = sampled_batch['image_in'].cuda(), sampled_batch['image'].cuda(), \ + sampled_batch['target_in'].cuda(), sampled_batch['target'].cuda() + + t1_krecon, t2_krecon = sampled_batch['image_krecon'].cuda(), sampled_batch['target_krecon'].cuda() + t_merge = torch.cat([t1_in, t2_in], dim=1) + + t2_out = network(t2_in, t1_in)['img_out'] + t1_out = None + + if args.input_normalize == "mean_std": + t1_mean = sample_stats['t1_mean'].data.cpu().numpy()[0] + t1_std = sample_stats['t1_std'].data.cpu().numpy()[0] + t2_mean = sample_stats['t2_mean'].data.cpu().numpy()[0] + t2_std = sample_stats['t2_std'].data.cpu().numpy()[0] + + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0] * t1_std + t1_mean, 0, 1) * 255).astype(np.uint8) + + + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0] * t2_std + t2_mean, 0, 1) * 255).astype(np.uint8) + + else: + if t1_out is not None: + t1_img = (np.clip(t1.data.cpu().numpy()[0, 0], 0, 1) * 
255).astype(np.uint8) + t1_out_img = (np.clip(t1_out.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t1_krecon_img = (np.clip(t1_krecon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + + t2_img = (np.clip(t2.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_out_img = (np.clip(t2_out.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + t2_krecon_img = (np.clip(t2_krecon.data.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8) + + + if t1_out is not None: + + MSE = mean_squared_error(t1_img, t1_out_img) + PSNR = peak_signal_noise_ratio(t1_img, t1_out_img) + SSIM = structural_similarity(t1_img, t1_out_img) + t1_MSE_all.append(MSE) + t1_PSNR_all.append(PSNR) + t1_SSIM_all.append(SSIM) + + MSE = mean_squared_error(t1_img, t1_krecon_img) + PSNR = peak_signal_noise_ratio(t1_img, t1_krecon_img) + SSIM = structural_similarity(t1_img, t1_krecon_img) + t1_MSE_krecon.append(MSE) + t1_PSNR_krecon.append(PSNR) + t1_SSIM_krecon.append(SSIM) + + + if t2_out is not None: + MSE = mean_squared_error(t2_img, t2_out_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_out_img) + SSIM = structural_similarity(t2_img, t2_out_img) + t2_MSE_all.append(MSE) + t2_PSNR_all.append(PSNR) + t2_SSIM_all.append(SSIM) + # print("[t2 MRI] MSE:", MSE, "PSNR:", PSNR, "SSIM:", SSIM) + + MSE = mean_squared_error(t2_img, t2_krecon_img) + PSNR = peak_signal_noise_ratio(t2_img, t2_krecon_img) + SSIM = structural_similarity(t2_img, t2_krecon_img) + t2_MSE_krecon.append(MSE) + t2_PSNR_krecon.append(PSNR) + t2_SSIM_krecon.append(SSIM) + + if t1_out is not None: + t1_mse = np.array(t1_MSE_all).mean() + t1_psnr = np.array(t1_PSNR_all).mean() + t1_ssim = np.array(t1_SSIM_all).mean() + + t1_krecon_mse = np.array(t1_MSE_krecon).mean() + t1_krecon_psnr = np.array(t1_PSNR_krecon).mean() + t1_krecon_ssim = np.array(t1_SSIM_krecon).mean() + + t2_mse = np.array(t2_MSE_all).mean() + t2_psnr = np.array(t2_PSNR_all).mean() + t2_ssim = np.array(t2_SSIM_all).mean() + + t2_krecon_mse = np.array(t2_MSE_krecon).mean() + t2_krecon_psnr = np.array(t2_PSNR_krecon).mean() + t2_krecon_ssim = np.array(t2_SSIM_krecon).mean() + + + if t2_psnr > best_status['T2_PSNR']: + best_status = {'T2_NMSE': t2_mse, 'T2_PSNR': t2_psnr, 'T2_SSIM': t2_ssim} + + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': network.state_dict()}, best_checkpoint_path) + print('New Best Network:') + + logging.info(f"[T2 MRI:] average MSE: {t2_mse} average PSNR: {t2_psnr} average SSIM: {t2_ssim}") + + if iter_num > max_iterations: + break + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': network.state_dict()}, + save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + writer.close() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/train_fastmri.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/train_fastmri.py new file mode 100644 index 0000000000000000000000000000000000000000..3e268812b01a003271ab21cfa0d7969f1e86ed4d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/train_fastmri.py @@ -0,0 +1,303 @@ +import os +import sys +from tqdm import tqdm +from tensorboardX import SummaryWriter +import shutil +import argparse +import logging +import time +import torch +import numpy as np +import torch.optim as optim +from torchvision import transforms +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import 
DataLoader +from torchvision.utils import make_grid +from networks.compare_models import build_model_from_name +from dataloaders.BRATS_dataloader_new import Hybrid as MyDataset +from dataloaders.BRATS_dataloader_new import RandomPadCrop, ToTensor, AddNoise +from networks.mynet import TwoBranch +from option import args +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from dataloaders.fastmri import build_dataset + + +train_data_path = args.root_path +test_data_path = args.root_path +snapshot_path = "model/" + args.exp + "/" + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +batch_size = args.batch_size * len(args.gpu.split(',')) +max_iterations = args.max_iterations +base_lr = args.base_lr + + + +def cc(img1, img2): + eps = torch.finfo(torch.float32).eps + """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" + N, C, _, _ = img1.shape + img1 = img1.reshape(N, C, -1) + img2 = img2.reshape(N, C, -1) + img1 = img1 - img1.mean(dim=-1, keepdim=True) + img2 = img2 - img2.mean(dim=-1, keepdim=True) + cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum( + img1 **2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1))) + cc = torch.clamp(cc, -1., 1.) + return cc.mean() + + + +def gradient_calllback(network): + for name, param in network.named_parameters(): + if param.grad is not None: + if param.grad.abs().mean() == 0: + print("Gradient of {} is 0".format(name)) + + else: + print("Gradient of {} is None".format(name)) + +class AMPLoss(nn.Module): + def __init__(self): + super(AMPLoss, self).__init__() + self.cri = nn.L1Loss() + + def forward(self, x, y): + x = torch.fft.rfft2(x, norm='backward') + x_mag = torch.abs(x) + y = torch.fft.rfft2(y, norm='backward') + y_mag = torch.abs(y) + + return self.cri(x_mag,y_mag) + + +class PhaLoss(nn.Module): + def __init__(self): + super(PhaLoss, self).__init__() + self.cri = nn.L1Loss() + + def forward(self, x, y): + x = torch.fft.rfft2(x, norm='backward') + x_mag = torch.angle(x) + y = torch.fft.rfft2(y, norm='backward') + y_mag = torch.angle(y) + + return self.cri(x_mag, y_mag) + +from metric import nmse, psnr, ssim, AverageMeter +from collections import defaultdict +@torch.no_grad() +def evaluate(model, data_loader, device): + model.eval() + nmse_meter = AverageMeter() + psnr_meter = AverageMeter() + ssim_meter = AverageMeter() + output_dic = defaultdict(dict) + target_dic = defaultdict(dict) + input_dic = defaultdict(dict) + + for id, data in enumerate(data_loader): + pd, pdfs, _ = data + target = pdfs[1] + + mean = pdfs[2] + std = pdfs[3] + + # print("get mean and std:", mean, std) + + fname = pdfs[4] + slice_num = pdfs[5] + + mean = mean.unsqueeze(1).unsqueeze(2) + std = std.unsqueeze(1).unsqueeze(2) + + mean = mean.to(device) + std = std.to(device) + + + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + + pd_img = pd_img.to(device) + pdfs_img = pdfs_img.to(device) + target = target.to(device) + + + outputs = model(pdfs_img, pd_img)['img_out'] + outputs = outputs.squeeze(1) + + # print("outputs shape:", outputs.shape, outputs.min(), outputs.max()) + + outputs = outputs * std + mean + target = target * std + mean + inputs = pdfs_img.squeeze(1) * std + mean + + # print("Ourputs after denormalization:", outputs.min(), outputs.max()) + + for i, f in enumerate(fname): + output_dic[f][slice_num[i]] = outputs[i] + target_dic[f][slice_num[i]] = target[i] + input_dic[f][slice_num[i]] = inputs[i] + + if id > 50: + break + + for name in output_dic.keys(): + f_output = 
torch.stack([v for _, v in output_dic[name].items()]) + f_target = torch.stack([v for _, v in target_dic[name].items()]) + our_nmse = nmse(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_psnr = psnr(f_target.cpu().numpy(), f_output.cpu().numpy()) + our_ssim = ssim(f_target.cpu().numpy(), f_output.cpu().numpy()) + + + nmse_meter.update(our_nmse, 1) + psnr_meter.update(our_psnr, 1) + ssim_meter.update(our_ssim, 1) + + print("==> Evaluate Metric") + print("Results ----------") + print("NMSE: {:.4}".format(nmse_meter.avg)) + print("PSNR: {:.4}".format(psnr_meter.avg)) + print("SSIM: {:.4}".format(ssim_meter.avg)) + print("------------------") + model.train() + + return {'NMSE': nmse_meter.avg, 'PSNR': psnr_meter.avg, 'SSIM':ssim_meter.avg} + + + +if __name__ == "__main__": + ## make logger file + if not os.path.exists(snapshot_path): + os.makedirs(snapshot_path) + + logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, + format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging.info(str(args)) + + network = TwoBranch(args).cuda() + # network = build_model_from_name(args).cuda() + device = torch.device('cuda') + network.to(device) + + if len(args.gpu.split(',')) > 1: + network = nn.DataParallel(network) + # network = nn.SyncBatchNorm.convert_sync_batchnorm(network) + + n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad) + print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + db_train = build_dataset(args, mode='train') + db_test = build_dataset(args, mode='val') + + trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) + + if args.phase == 'train': + network.train() + + params = list(network.parameters()) + optimizer1 = optim.AdamW(params, lr=base_lr, betas=(0.9, 0.999), weight_decay=1e-4) + scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=20000, gamma=0.5) + + writer = SummaryWriter(snapshot_path + '/log') + + iter_num = 0 + max_epoch = max_iterations // len(trainloader) + 1 + + + best_status = {'NMSE': 10000000, 'PSNR': 0, 'SSIM': 0} + fft_weight=0.01 + criterion = nn.L1Loss().to(device, non_blocking=True) + amploss = AMPLoss().to(device, non_blocking=True) + phaloss = PhaLoss().to(device, non_blocking=True) + for epoch_num in tqdm(range(max_epoch), ncols=70): + time1 = time.time() + for i_batch, sampled_batch in enumerate(trainloader): + time2 = time.time() + # print("time for data loading:", time2 - time1) + + pd, pdfs, _ = sampled_batch + target = pdfs[1] + + + mean = pdfs[2] + std = pdfs[3] + + + # print("mean:", mean, "std:", std) + + pd_img = pd[1].unsqueeze(1) + pdfs_img = pdfs[0].unsqueeze(1) + target = target.unsqueeze(1) + + pd_img = pd_img.to(device) # [4, 1, 320, 320] + pdfs_img = pdfs_img.to(device) # [4, 1, 320, 320] + target = target.to(device) # [4, 1, 320, 320] + + time3 = time.time() + # breakpoint() + outputs = network(pdfs_img, pd_img) + + loss = criterion(outputs['img_out'], target) + \ + fft_weight * amploss(outputs['img_fre'], target) + fft_weight * phaloss( + outputs['img_fre'], + target) + \ + criterion(outputs['img_fre'], target) + + time4 = time.time() + + + optimizer1.zero_grad() + loss.backward() + + if args.clip_grad == "True": + ### clip the gradients to a small range. 
+ torch.nn.utils.clip_grad_norm_(network.parameters(), 0.01) + + optimizer1.step() + scheduler1.step() + + time5 = time.time() + + # summary + iter_num = iter_num + 1 + + if iter_num % 100 == 0: + logging.info('iteration %d : learning rate : %f loss : %f ' % (iter_num, scheduler1.get_lr()[0], loss.item())) + break + + if iter_num % 20000 == 0: + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') + torch.save({'network': network.state_dict()}, save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + + if iter_num > max_iterations: + break + time1 = time.time() + + + ## ================ Evaluate ================ + logging.info(f'Epoch {epoch_num} Evaluation:') + # print() + eval_result = evaluate(network, testloader, device) + + if eval_result['PSNR'] > best_status['PSNR']: + best_status = {'NMSE': eval_result['NMSE'], 'PSNR': eval_result['PSNR'], 'SSIM': eval_result['SSIM']} + best_checkpoint_path = os.path.join(snapshot_path, 'best_checkpoint.pth') + + torch.save({'network': network.state_dict()}, best_checkpoint_path) + print('New Best Network:') + logging.info(f"average MSE: {eval_result['NMSE']} average PSNR: {eval_result['PSNR']} average SSIM: {eval_result['SSIM']}") + + if iter_num > max_iterations: + break + print(best_status) + save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth') + torch.save({'network': network.state_dict()}, + save_mode_path) + logging.info("save model to {}".format(save_mode_path)) + writer.close() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f733ac3d3d6527cae765d48cf58b0c02167532 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/experiments/FSMNet/utils.py @@ -0,0 +1,33 @@ +import numpy as np +import torch + + +def bright(x, a,b): + # input datatype np.uint8 + x = np.array(x, dtype='float') + x = x/(b-a) - 255*a/(b-a) + x[x>255.0] = 255.0 + x[x<0.0] = 0.0 + x = x.astype(np.uint8) + return x + +def trunc(x): + # input datatype float + x[x>255.0] = 255.0 + x[x<0.0] = 0.0 + return x + + + +def cc(img1, img2): + eps = torch.finfo(torch.float32).eps + """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" + N, C, _, _ = img1.shape + img1 = img1.reshape(N, C, -1) + img2 = img2.reshape(N, C, -1) + img1 = img1 - img1.mean(dim=-1, keepdim=True) + img2 = img2 - img2.mean(dim=-1, keepdim=True) + cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum( + img1 **2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1))) + cc = torch.clamp(cc, -1., 1.) 
+ return cc.mean() \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/main.py b/MRI_recon/new_code/Frequency-Diffusion-main/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e98f1a3b2df63dd5a49387208aaca57dc493a0f0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/main.py @@ -0,0 +1,249 @@ +import torchvision +import os +import errno +import shutil +import argparse +from networks import TwoBranchModel,Unet +from diffusion_pytorch import GaussianDiffusion, Trainer +import torch, warnings + +from pytorch_lightning.callbacks import Callback +warnings.filterwarnings("ignore") + + +class DebugDataloaderCallback(Callback): + # + def __init__(self): + super().__init__() + self.counter = 0 + + def on_train_start(self, trainer, pl_module): + self.counter += 1 + if (self.counter + 1 ) % 10 == 0: + trainer.train_dataloader.dataset.update_chunk() + + + +def create_folder(path): + try: + os.mkdir(path) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + pass + + +def del_folder(path): + try: + shutil.rmtree(path) + except OSError as exc: + pass + + +create = 0 + +if create: + trainset = torchvision.datasets.CIFAR10( + root='./data', train=True, download=True) + root = './root_cifar10/' + del_folder(root) + create_folder(root) + + for i in range(10): + lable_root = root + str(i) + '/' + create_folder(lable_root) + + for idx in range(len(trainset)): + img, label = trainset[idx] + print(idx) + img.save(root + str(label) + '/' + str(idx) + '.png') + + +parser = argparse.ArgumentParser() +parser.add_argument('--time_steps', default=50, type=int) +parser.add_argument('--train_steps', default=700000, type=int) +parser.add_argument('--save_folder', default=None, type=str) + +parser.add_argument('--load_path', default=None, type=str) +parser.add_argument('--data_path', default='./root_cifar10/', type=str) +parser.add_argument('--fade_routine', default='Random_Incremental', type=str) +parser.add_argument('--sampling_routine', default='x0_step_down', type=str) +parser.add_argument('--discrete', action="store_true") +parser.add_argument('--remove_time_embed', action="store_true") +parser.add_argument('--residual', action="store_true") +parser.add_argument('--tag', default='', type=str) +parser.add_argument('--accelerate_factor', default=4, help="4 | 8", type=int) + + +parser.add_argument('--normalizer', default='mean_std', type=str) + +parser.add_argument('--mode', default='train', type=str) +parser.add_argument('--example_frequency_img', default=None, type=str) +# specific arguments +# parser.add_argument('--initial_mask', default=11, type=int) +parser.add_argument('--kernel_std', default=0.1, type=float) + +parser.add_argument('--dataset', default='brain', type=str) +parser.add_argument('--domain', default=None, type=str) +parser.add_argument('--aux_modality', default=None, type=str) +parser.add_argument('--deviceid', default=0, type=int) +parser.add_argument('--num_channels', default=1, type=int) +parser.add_argument('--train_bs', default=24, type=int) +parser.add_argument('--diffusion_type', default='twobranch_fade', type=str) +parser.add_argument('--debug', action="store_true") +parser.add_argument('--image_size', default=128) +parser.add_argument('--loss_type', default='l1', type=str) + +args = parser.parse_args() +print(args) +os.environ["CUDA_VISIBLE_DEVICES"] = str(args.deviceid) + +image_channels = 1 + +diffusion_type = args.diffusion_type +# diffusion_type = "twobranch_fade" # model_degradation # fade | kspace 
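For orientation, the bright(x, a, b) helper defined in utils.py earlier in this patch linearly stretches the intensity window [255*a, 255*b] of a uint8 image to the full [0, 255] range and clips everything outside it; the BraTS test script saves its reconstructions through bright(img, 0, 0.8), which brightens them by letting the top 20% of the range saturate to white. A small usage sketch, assuming it runs from the FSMNet experiments directory so that utils.py is importable:

import numpy as np
from utils import bright  # windowing helper from this patch

img = (np.random.rand(240, 240) * 255).astype(np.uint8)
out = bright(img, 0, 0.8)   # maps [0, 204] to [0, 255]; values above 204 clip to 255
assert out.dtype == np.uint8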
+model_name = diffusion_type.split("_")[0] # unet | twobranch + +save_and_sample_every = 1000 + +if args.debug: + args.train_steps = 100 + args.time_steps = 5 + +model = None + + +if isinstance(args.image_size, str): + length = len(args.image_size.split(",")) + if length == 1: + args.image_size = (int(args.image_size), int(args.image_size)) + elif length == 2: + args.image_size = (int(args.image_size.split(",")[0]), int(args.image_size.split(",")[1])) +else: + args.image_size = (args.image_size, args.image_size) + + + +if model_name == "unet": + model = Unet(resolution=args.image_size[0], + in_channels=1, + out_ch=1, + ch=128, + ch_mult=(1, 2, 2, 2), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.1).cuda() + +elif model_name == "twounet": + model = TwoBranchNewModel(resolution=args.image_size[0], + in_channels=1, + out_ch=1, + ch=128, + ch_mult=(1, 2, 2, 2), + num_res_blocks=3, + attn_resolutions=(16,), + dropout=0.1).cuda() # Drop out used to be 0.1 + + +elif model_name == "twobranch": + + base_num_every_group = 2 + num_features = 64 + act = "PReLU" + num_channels = 1 + + from networks.networks_fsm.mynet import TwoBranch as TwoBranchModel + + + model = TwoBranchModel( + num_features, act, base_num_every_group, num_channels + ).cuda() + +fp16 = False + + +n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) +print('number of params: %.2f M' % (n_parameters / 1024 / 1024)) + + +diffusion = GaussianDiffusion( + diffusion_type, + model, + image_size=args.image_size[0], # Used to be 32 + channels=image_channels, + device_of_kernel='cuda', + timesteps=args.time_steps, + loss_type=args.loss_type, #$'l1', + kernel_std=args.kernel_std, + fade_routine=args.fade_routine, + sampling_routine=args.sampling_routine, + discrete=args.discrete, + accelerate_factor=args.accelerate_factor, + fp16=fp16, + normalizer=args.normalizer, + example_frequency_img=args.example_frequency_img, +).cuda() + + +diffusion = torch.nn.DataParallel(diffusion, device_ids=range(torch.cuda.device_count())) + +print("=== train_steps:", args.train_steps) +os.makedirs(args.save_folder, exist_ok=True) + +if args.debug: + args.save_folder = args.save_folder + "_debug" +else: + args.save_folder = args.save_folder + f"_{args.tag}" + save_and_sample_every = 500 + + +# if os.path.exists(args.save_folder): +name = args.save_folder.split("/")[-1] +number = os.listdir(args.save_folder.rstrip(name)).__len__() +if args.mode == "test": + number = "test_" + str(number) + +args.save_folder = os.path.join(args.save_folder.rstrip(name), f"{number}_" + name) + +# create the folder and parent folders +os.makedirs(args.save_folder, exist_ok=True) + + + +print("SAVE FOLDER: ", args.save_folder) + +trainer = Trainer( + diffusion, + args.data_path, + mode = args.mode, + norm = args.normalizer, + image_size=args.image_size, # Used to be 32 + train_batch_size=args.train_bs, + train_lr= 1e-4, # 2e-5 + train_num_steps=args.train_steps, + gradient_accumulate_every=1, + ema_decay=0.995, + save_and_sample_every=save_and_sample_every, + fp16=fp16, + results_folder=args.save_folder, + load_path=args.load_path, + dataset=args.dataset, + domain=args.domain, + aux_modality=args.aux_modality, + debug=args.debug, + num_channels=args.num_channels + # accelerator="gpu", + # callbacks=[DebugDataloaderCallback()], +) + + +if args.mode == "train": + trainer.train() + +elif args.mode == "test": + # ['default', 'x0_step_down', 'x0_step_down_fre', "fre_progressive"]: + trainer.test_loader('x0_step_down_fre') + + + + diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/metrics/fid.py b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa3918b97302d209d4fb2f88dc50ca7ef1476b5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/fid.py @@ -0,0 +1,334 @@ +import torch +from torch import nn +from torchvision.models import inception_v3 +import cv2 +import multiprocessing +import numpy as np +import glob +import os +from scipy import linalg + + +def to_cuda(elements): + """ + Transfers elements to cuda if GPU is available + Args: + elements: torch.tensor or torch.nn.module + -- + Returns: + elements: same as input on GPU memory, if available + """ + if torch.cuda.is_available(): + return elements.cuda() + return elements + + +class PartialInceptionNetwork(nn.Module): + + def __init__(self, transform_input=True): + super().__init__() + self.inception_network = inception_v3(pretrained=True) + self.inception_network.Mixed_7c.register_forward_hook(self.output_hook) + self.transform_input = transform_input + + def output_hook(self, module, input, output): + # N x 2048 x 8 x 8 + self.mixed_7c_output = output + + def forward(self, x): + """ + Args: + x: shape (N, 3, 299, 299) dtype: torch.float32 in range 0-1 + Returns: + inception activations: torch.tensor, shape: (N, 2048), dtype: torch.float32 + """ + assert x.shape[1:] == (3, 299, 299), "Expected input shape to be: (N,3,299,299)" + \ + ", but got {}".format(x.shape) + x = x * 2 - 1 # Normalize to [-1, 1] + + # Trigger output hook + self.inception_network(x) + + # Output: N x 2048 x 1 x 1 + activations = self.mixed_7c_output + activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1)) + activations = activations.view(x.shape[0], 2048) + return activations + +class PartialResnet3D(nn.Module): + + def __init__(self, transform_input=True): + super().__init__() + model = torch.hub.load('facebookresearch/pytorchvideo', 'slow_r50', + pretrained=True) + + model.blocks[5].proj = nn.Identity() + model.blocks[5].output_pool = nn.Identity() + + self.network = model + + # input = torch.ones(1, 3, 8, 256, 256) + + + # self.inception_network.Mixed_7c.register_forward_hook(self.output_hook) + self.transform_input = transform_input + + # def output_hook(self, module, input, output): + # N x 2048 x 8 x 8 + # self.mixed_7c_output = output + + def forward(self, x): + """ + Args: + x: shape (N, 3, 299, 299) dtype: torch.float32 in range 0-1 + Returns: + inception activations: torch.tensor, shape: (N, 2048), dtype: torch.float32 + """ + # assert x.shape[1:] == (3, 256, 256), "Expected input shape to be: (N,3,299,299)" + \ + # ", but got {}".format(x.shape) + + x = x * 2 - 1 # Normalize to [-1, 1] + + # Trigger output hook + activations = self.inception_network(x) + + + # Output: N x 2048 x 1 x 1 + # activations = self.mixed_7c_output + activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1)) + activations = activations.view(x.shape[0], 2048) + return activations + + + +def get_activations(images, batch_size): + """ + Calculates activations for last pool layer for all iamges + -- + Images: torch.array shape: (N, 3, 299, 299), dtype: torch.float32 + batch size: batch size used for inception network + -- + Returns: np array shape: (N, 2048), dtype: np.float32 + """ + assert images.shape[1:] == (3, 299, 299), "Expected input shape to be: (N,3,299,299)" + \ + ", but got {}".format(images.shape) + + num_images = images.shape[0] + inception_network = 
PartialInceptionNetwork() + inception_network = to_cuda(inception_network) + inception_network.eval() + n_batches = int(np.ceil(num_images / batch_size)) + inception_activations = np.zeros((num_images, 2048), dtype=np.float32) + for batch_idx in range(n_batches): + start_idx = batch_size * batch_idx + end_idx = batch_size * (batch_idx + 1) + + ims = images[start_idx:end_idx] + ims = to_cuda(ims) + activations = inception_network(ims) + activations = activations.detach().cpu().numpy() + assert activations.shape == (ims.shape[0], 2048), "Expexted output shape to be: {}, but was: {}".format( + (ims.shape[0], 2048), activations.shape) + inception_activations[start_idx:end_idx, :] = activations + return inception_activations + + +def calculate_activation_statistics(images, batch_size): + """Calculates the statistics used by FID + Args: + images: torch.tensor, shape: (N, 3, H, W), dtype: torch.float32 in range 0 - 1 + batch_size: batch size to use to calculate inception scores + Returns: + mu: mean over all activations from the last pool layer of the inception model + sigma: covariance matrix over all activations from the last pool layer + of the inception model. + + """ + act = get_activations(images, batch_size) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +# Modified from: https://github.com/bioinf-jku/TTUR/blob/master/fid.py +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of the pool_3 layer of the + inception net ( like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted + on an representive data set. + -- sigma1: The covariance matrix over activations of the pool_3 layer for + generated samples. + -- sigma2: The covariance matrix over activations of the pool_3 layer, + precalcualted on an representive data set. + + Returns: + -- : The Frechet Distance. 
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +def preprocess_image(im): + """Resizes and shifts the dynamic range of image to 0-1 + Args: + im: np.array, shape: (H, W, 3), dtype: float32 between 0-1 or np.uint8 + Return: + im: torch.tensor, shape: (3, 299, 299), dtype: torch.float32 between 0-1 + """ + # print("im shape:", im.shape) + if im.shape[0] == 3: + im = im.transpose(1, 2, 0) + # CHW->HWC + + # print("new im shape:", im.shape) + + + assert im.shape[2] == 3 + assert len(im.shape) == 3 + if im.dtype == np.uint8: + im = im.astype(np.float32) / 255 + + im = cv2.resize(im, (299, 299)) + im = np.rollaxis(im, axis=2) + im = torch.from_numpy(im) + assert im.max() <= 1.0 + assert im.min() >= 0.0 + assert im.dtype == torch.float32 + assert im.shape == (3, 299, 299) + + return im + + +def preprocess_images(images, use_multiprocessing): + """Resizes and shifts the dynamic range of image to 0-1 + Args: + images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + Return: + final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1 + """ + if use_multiprocessing: + with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: + jobs = [] + for im in images: + job = pool.apply_async(preprocess_image, (im,)) + jobs.append(job) + final_images = torch.zeros(images.shape[0], 3, 299, 299) + for idx, job in enumerate(jobs): + im = job.get() + final_images[idx] = im # job.get() + else: + final_images = torch.stack([preprocess_image(im) for im in images], dim=0) + + assert final_images.shape == (images.shape[0], 3, 299, 299) + assert final_images.max() <= 1.0 + assert final_images.min() >= 0.0 + assert final_images.dtype == torch.float32 + return final_images + + +def calculate_fid(images1, images2, use_multiprocessing, batch_size): + """ Calculate FID between images1 and images2 + Args: + images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + batch size: batch size used for inception network + Returns: + FID (scalar) + """ + images1 = preprocess_images(images1, use_multiprocessing) + images2 = preprocess_images(images2, use_multiprocessing) + mu1, sigma1 = calculate_activation_statistics(images1, batch_size) + mu2, sigma2 = 
calculate_activation_statistics(images2, batch_size) + fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) + return fid + + +def load_images(path): + """ Loads all .png or .jpg images from a given path + Warnings: Expects all images to be of same dtype and shape. + Args: + path: relative path to directory + Returns: + final_images: np.array of image dtype and shape. + """ + image_paths = [] + image_extensions = ["png", "jpg"] + for ext in image_extensions: + print("Looking for images in", os.path.join(path, "*.{}".format(ext))) + for impath in glob.glob(os.path.join(path, "*.{}".format(ext))): + image_paths.append(impath) + first_image = cv2.imread(image_paths[0]) + W, H = first_image.shape[:2] + image_paths.sort() + image_paths = image_paths + final_images = np.zeros((len(image_paths), H, W, 3), dtype=first_image.dtype) + for idx, impath in enumerate(image_paths): + im = cv2.imread(impath) + im = im[:, :, ::-1] # Convert from BGR to RGB + assert im.dtype == final_images.dtype + final_images[idx] = im + return final_images + + +if __name__ == "__main__": + from optparse import OptionParser + + parser = OptionParser() + parser.add_option("--p1", "--path1", dest="path1", + help="Path to directory containing the real images") + parser.add_option("--p2", "--path2", dest="path2", + help="Path to directory containing the generated images") + parser.add_option("--multiprocessing", dest="use_multiprocessing", + help="Toggle use of multiprocessing for image pre-processing. Defaults to use all cores", + default=False, + action="store_true") + parser.add_option("-b", "--batch-size", dest="batch_size", + help="Set batch size to use for InceptionV3 network", + type=int) + + options, _ = parser.parse_args() + assert options.path1 is not None, "--path1 is an required option" + assert options.path2 is not None, "--path2 is an required option" + assert options.batch_size is not None, "--batch_size is an required option" + images1 = load_images(options.path1) + images2 = load_images(options.path2) + fid_value = calculate_fid(images1, images2, options.use_multiprocessing, options.batch_size) + print(fid_value) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/metrics/fid_3d.py b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/fid_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..054c7ed99d5aa823814bed42c4de921fb8cdf56b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/fid_3d.py @@ -0,0 +1,307 @@ +import torch +from torch import nn +from torchvision.models import inception_v3 +import cv2 +import multiprocessing +import numpy as np +import glob +import os +from scipy import linalg + + +def to_cuda(elements): + """ + Transfers elements to cuda if GPU is available + Args: + elements: torch.tensor or torch.nn.module + -- + Returns: + elements: same as input on GPU memory, if available + """ + if torch.cuda.is_available(): + return elements.cuda() + return elements + + + + +class PartialResnet3D(nn.Module): + + def __init__(self, transform_input=True): + super().__init__() + model = torch.hub.load('facebookresearch/pytorchvideo', 'slow_r50', + pretrained=True) + + model.blocks[5].proj = nn.Identity() + model.blocks[5].output_pool = nn.Identity() + + self.network = model + + # input = torch.ones(1, 3, 8, 256, 256) + + # self.inception_network.Mixed_7c.register_forward_hook(self.output_hook) + self.transform_input = transform_input + + # def output_hook(self, module, input, output): + # N x 98304 x 8 x 8 + # 
self.mixed_7c_output = output + + def forward(self, x): + """ + Args: + x: shape (N, 3, 299, 299) dtype: torch.float32 in range 0-1 + Returns: + inception activations: torch.tensor, shape: (N, 98304), dtype: torch.float32 + """ + # assert x.shape[1:] == (3, 256, 256), "Expected input shape to be: (N,3,299,299)" + \ + # ", but got {}".format(x.shape) + + x = x * 2 - 1 # Normalize to [-1, 1] + + # Trigger output hook + activations = self.network(x) + # print("activations shape:", activations.shape) # activations shape: torch.Size([1, 98304]) + + # Output: N x 98304 x 1 x 1 + # activations = self.mixed_7c_output + # activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1)) + activations = activations.view(x.shape[0], 98304) + return activations + + +def get_activations(images, batch_size): + """ + Calculates activations for last pool layer for all iamges + -- + Images: torch.array shape: (N, 3, 299, 299), dtype: torch.float32 + batch size: batch size used for inception network + -- + Returns: np array shape: (N, 98304), dtype: np.float32 + """ + # assert images.shape[1:] == (3, 299, 299), "Expected input shape to be: (N,3,299,299)" + \ + # ", but got {}".format(images.shape) + + num_images = images.shape[0] + inception_network = PartialResnet3D() + inception_network = to_cuda(inception_network) + inception_network.eval() + n_batches = int(np.ceil(num_images / batch_size)) + inception_activations = np.zeros((num_images, 98304), dtype=np.float32) + for batch_idx in range(n_batches): + start_idx = batch_size * batch_idx + end_idx = batch_size * (batch_idx + 1) + + ims = images[start_idx:end_idx] + ims = to_cuda(ims) + activations = inception_network(ims) + activations = activations.detach().cpu().numpy() + assert activations.shape == (ims.shape[0], 98304), "Expexted output shape to be: {}, but was: {}".format( + (ims.shape[0], 98304), activations.shape) + inception_activations[start_idx:end_idx, :] = activations + return inception_activations + + +def calculate_activation_statistics(images, batch_size): + """Calculates the statistics used by FID + Args: + images: torch.tensor, shape: (N, 3, H, W), dtype: torch.float32 in range 0 - 1 + batch_size: batch size to use to calculate inception scores + Returns: + mu: mean over all activations from the last pool layer of the inception model + sigma: covariance matrix over all activations from the last pool layer + of the inception model. + + """ + act = get_activations(images, batch_size) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +# Modified from: https://github.com/bioinf-jku/TTUR/blob/master/fid.py +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of the pool_3 layer of the + inception net ( like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted + on an representive data set. + -- sigma1: The covariance matrix over activations of the pool_3 layer for + generated samples. + -- sigma2: The covariance matrix over activations of the pool_3 layer, + precalcualted on an representive data set. + + Returns: + -- : The Frechet Distance. 
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" + assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +def preprocess_image(im): + """Resizes and shifts the dynamic range of image to 0-1 + Args: + im: np.array, shape: (H, W, 3), dtype: float32 between 0-1 or np.uint8 + Return: + im: torch.tensor, shape: (3, 299, 299), dtype: torch.float32 between 0-1 + """ + # print("im shape:", im.shape) + if im.shape[0] == 3: + im = im.transpose(1, 2, 0) + # CHW->HWC + + # print("new im shape:", im.shape) + + assert im.shape[2] == 3 + assert len(im.shape) == 3 + if im.dtype == np.uint8: + im = im.astype(np.float32) / 255 + + im = cv2.resize(im, (299, 299)) + im = np.rollaxis(im, axis=2) + im = torch.from_numpy(im) + assert im.max() <= 1.0 + assert im.min() >= 0.0 + assert im.dtype == torch.float32 + assert im.shape == (3, 299, 299) + + return im + + +def preprocess_images(images, use_multiprocessing): + """Resizes and shifts the dynamic range of image to 0-1 + Args: + images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + Return: + final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1 + """ + if use_multiprocessing: + with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: + jobs = [] + for im in images: + job = pool.apply_async(preprocess_image, (im,)) + jobs.append(job) + final_images = torch.zeros(images.shape[0], 3, 256, 256) + for idx, job in enumerate(jobs): + im = job.get() + final_images[idx] = im # job.get() + else: + final_images = torch.stack([preprocess_image(im) for im in images], dim=0) + + # print("final_images shape:", final_images.shape) + # assert final_images.shape == (1, 3, images.shape[0], 256, 256) + assert final_images.max() <= 1.0 + assert final_images.min() >= 0.0 + assert final_images.dtype == torch.float32 + return final_images + + +def calculate_fid_3d(images1, images2, use_multiprocessing, batch_size): + """ Calculate FID between images1 and images2 + Args: + images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8 + use_multiprocessing: If multiprocessing should be used to pre-process the images + batch size: batch size used for inception network + Returns: + FID (scalar) + """ + images1 = preprocess_images(images1, use_multiprocessing) + images2 = preprocess_images(images2, use_multiprocessing) + + # C, 3, H, W -> 1, 3, C, H, W + images1 = 
images1.unsqueeze(0).permute(0, 2, 1, 3, 4) + images2 = images2.unsqueeze(0).permute(0, 2, 1, 3, 4) + + mu1, sigma1 = calculate_activation_statistics(images1, batch_size) + mu2, sigma2 = calculate_activation_statistics(images2, batch_size) + fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2) + return fid + + +def load_images(path): + """ Loads all .png or .jpg images from a given path + Warnings: Expects all images to be of same dtype and shape. + Args: + path: relative path to directory + Returns: + final_images: np.array of image dtype and shape. + """ + image_paths = [] + image_extensions = ["png", "jpg"] + for ext in image_extensions: + print("Looking for images in", os.path.join(path, "*.{}".format(ext))) + for impath in glob.glob(os.path.join(path, "*.{}".format(ext))): + image_paths.append(impath) + first_image = cv2.imread(image_paths[0]) + W, H = first_image.shape[:2] + image_paths.sort() + image_paths = image_paths + final_images = np.zeros((len(image_paths), H, W, 3), dtype=first_image.dtype) + for idx, impath in enumerate(image_paths): + im = cv2.imread(impath) + im = im[:, :, ::-1] # Convert from BGR to RGB + assert im.dtype == final_images.dtype + final_images[idx] = im + return final_images + + +if __name__ == "__main__": + from optparse import OptionParser + + parser = OptionParser() + parser.add_option("--p1", "--path1", dest="path1", + help="Path to directory containing the real images") + parser.add_option("--p2", "--path2", dest="path2", + help="Path to directory containing the generated images") + parser.add_option("--multiprocessing", dest="use_multiprocessing", + help="Toggle use of multiprocessing for image pre-processing. Defaults to use all cores", + default=False, + action="store_true") + parser.add_option("-b", "--batch-size", dest="batch_size", + help="Set batch size to use for InceptionV3 network", + type=int) + + options, _ = parser.parse_args() + assert options.path1 is not None, "--path1 is an required option" + assert options.path2 is not None, "--path2 is an required option" + assert options.batch_size is not None, "--batch_size is an required option" + images1 = load_images(options.path1) + images2 = load_images(options.path2) + fid_value = calculate_fid(images1, images2, options.use_multiprocessing, options.batch_size) + print(fid_value) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/metrics/frequency_loss.py b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/frequency_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..642abd11d2dc3af909744b6fff3ac8463bf0f5ad --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/frequency_loss.py @@ -0,0 +1,43 @@ +import torch + + + +class AMPLoss(torch.nn.Module): + def __init__(self, epsilon=1e-8, loss="l1", norm='backward'): + super(AMPLoss, self).__init__() + self.mask_region = False # TODO + + if loss == "l1": + self.cri = torch.nn.L1Loss(reduction="sum" if self.mask_region else "mean") + else: + self.cri = torch.nn.MSELoss(reduction="sum" if self.mask_region else "mean") + + self.epsilon = epsilon # To prevent division by zero + self.norm = norm # Normalization for FFT + + def forward(self, x, y, k): + # Perform FFT and compute magnitudes + x_fft = torch.fft.rfft2(x, norm=self.norm) + y_fft = torch.fft.rfft2(y, norm=self.norm) + + x_mag = torch.clamp(torch.abs(x_fft), min=self.epsilon) # Clamp to avoid zeros + y_mag = torch.clamp(torch.abs(y_fft), min=self.epsilon) # Clamp to avoid zeros + + x_phase = torch.angle(x_fft) + 
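        # Note on the FFT used above: torch.fft.rfft2 keeps only the non-negative
        # frequencies along the last axis, so x_fft and y_fft have width W // 2 + 1.
        # That is why the mask_region branch below crops k to k[..., :W // 2 + 1]
        # before weighting the magnitude/phase terms and dividing by k_total.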
y_phase = torch.angle(y_fft) + + if self.mask_region: + W = x.shape[-1] + k = (1 - k.to(x.device)) + k = k[..., :W // 2 + 1] + k_total = torch.sum(k) + + x_mag = x_mag * k + y_mag = y_mag * k + x_phase = x_phase * k + y_phase = y_phase * k + # Compute L1 loss between magnitudes + return self.cri(x_mag, y_mag)/k_total + self.cri(x_phase, y_phase)/k_total + else: + # Compute L1 loss between magnitudes + return self.cri(x_mag, y_mag) + self.cri(x_phase, y_phase) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/metrics/lpips.py b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a30b875fb4aa39ccd8419759d2f841d62bbad6 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/lpips.py @@ -0,0 +1,184 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, 
map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + input = input.float() + target = target.float() + + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/metrics/nmse.py b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/nmse.py new file mode 100644 index 0000000000000000000000000000000000000000..790122086edaf81b7c4e268adacb1fff6b0ce3a9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/metrics/nmse.py @@ -0,0 +1,5 @@ +import numpy as np + +def nmse(gt, pred): + """Compute Normalized Mean Squared Error (NMSE)""" + return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 \ No newline at end of file diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/networks/Unet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/Unet.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c154dd8c2912e0665589496ce1154080ecc1a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/Unet.py @@ -0,0 +1,332 @@ +import math +import torch +import torch.nn as nn + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = 
self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + + + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + 
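                # The last of the num_res_blocks+1 decoder blocks at this level
                # concatenates the activation that forward() pushes onto hs before
                # this level's encoder blocks run (the conv_in output or the previous
                # level's downsample), which still has ch * in_ch_mult[i_level]
                # channels, hence the adjusted skip_in above.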
block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t): + assert x.shape[2] == x.shape[3] == self.resolution + + # timestep embedding + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + + # print(t) + # print(temb) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f42d964b4b0e18ce8995b50019d171bd2340ad --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/__init__.py @@ -0,0 +1,2 @@ +from .Unet import Model as Unet +from .st_branch_model.model import TwoBranchModel \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ARTUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae23fed16b0d9454a5ea09f17b4322f1e9d7e928 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ARTUnet.py @@ -0,0 +1,751 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
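                # Padded positions were marked with -1 in `mask`; filling them with a
                # large negative bias (NEG_INF) drives their softmax weights to ~zero,
                # so padding never contributes to the sparse-attention output.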
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + # print("x shape:", x.shape) + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +########################################################################## +class CS_TransformerBlock(nn.Module): + r""" 在ART Transformer Block的基础上添加了channel attention的分支。 + 参考论文: Dual Aggregation Transformer for Image Super-Resolution + 及其代码: https://github.com/zhengchen1999/DAT/blob/main/basicsr/archs/dat_arch.py + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.InstanceNorm2d(dim), + nn.GELU()) + + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + # nn.InstanceNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1)) + + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.InstanceNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # print("x before dwconv:", x.shape) + # convolution output + conv_x = self.dwconv(x.permute(0, 3, 1, 2)) + # conv_x = x.permute(0, 3, 1, 2) + # print("x after dwconv:", x.shape) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + attened_x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + attened_x = attened_x[:, :H, :W, :].contiguous() + + attened_x = attened_x.view(B, H * W, C) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + channel_map = self.channel_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, 1, C) + # S-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W) + spatial_map = self.spatial_interaction(attention_reshape) + + # C-I + attened_x = attened_x * torch.sigmoid(channel_map) + # S-I + conv_x = torch.sigmoid(spatial_map) * conv_x + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + + x = attened_x + conv_x + + # 
FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ARTUnet_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ARTUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d210d46e2a4d73781b3b16b63a53fdc0f25b9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ARTUnet_new.py @@ -0,0 +1,823 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +Attention Retractable Transformer (ART) model for Real Image Denoising task +Unet的格式构建image restoration 网络。每个transformer block中都是dense self-attention和sparse self-attention交替连接。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True, visualize_attention_maps=False): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + self.visualize_attention_maps = visualize_attention_maps + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + # assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + # qkv: 3, B_, num_heads, N, C // num_heads + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + # B_, num_heads, N, C // num_heads * + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + attn_map = attn.detach().cpu().numpy().mean(axis=1) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) # (B * nP, nHead, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + # attn: B * nP, num_heads, N, N + # v: B * nP, num_heads, N, C // num_heads + # -attn-@-v-> B * nP, num_heads, N, C // num_heads + # -transpose-> B * nP, N, num_heads, C 
// num_heads + # -reshape-> B * nP, N, C + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, G ** 2, G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, Gh * Gw, Gh * Gw) + else: + x = self.attn(x, Gh, Gw, mask=attn_mask) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, attn_map + else: + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +class ConcatTransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. 
+ num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + visualize_attention_maps=False + ): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.visualize_attention_maps = visualize_attention_maps + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True, visualize_attention_maps=visualize_attention_maps + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + ################## + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // 
self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + ################## + + x = torch.cat((x, y), dim=1) + + shortcut = x + x = self.norm1(x) + + # MSA + # print("attn mask:", attn_mask) + if self.visualize_attention_maps: + x, attn_map = self.attn(x, 2 * Gh, Gw, mask=attn_mask) # x: nP*B, Gh*Gw, C attn_map: np*B, N, N + if self.ds_flag == 0: + attn_map = attn_map.reshape(B, Hd // G, Wd // G, 2 * G ** 2, 2 * G ** 2) + elif self.ds_flag == 1: + attn_map = attn_map.reshape(B, I, I, 2 * Gh * Gw, 2 * Gh * Gw) + else: + x = self.attn(x, 2 * Gh, Gw, mask=attn_mask) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + output = x + + # print(x.shape, Gh, Gw) + x = output[:, :Gh * Gw, :] + y = output[:, Gh * Gw:, :] + assert x.shape == y.shape + + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = y.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = y.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + # print("x:", x.shape) + + if self.visualize_attention_maps: + return x, y, attn_map + else: + return x, y + + def get_patches(self, x, x_size): + H, W = x_size + B, H, W, C = x.shape + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + G = I = None + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, 
Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + return x, attn_mask, (Hd, Wd), (Gh, Gw), (pad_r, pad_b), G, I + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + # print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = 
self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=1, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ART_Restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ART_Restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c23123f0ac1a8e82e2bf907658ce8d91951c3e4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ART_Restormer.py @@ -0,0 +1,592 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure, where each block is composed of an ART block and a Restormer block connected in series. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
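# positions flagged -1 in the padding mask above are filled with NEG_INF below, so the softmax assigns them near-zero attention weight +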
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[4, 4, 4, 4], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + + self.down1_2 = Downsample(dim) ## 
From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + 
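# the Level-1 decoder below therefore runs at 2*dim channels: in forward(), the upsampled Level-2 features are concatenated with the Level-1 encoder skip without a channel-reducing 1x1 conv +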
self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + def forward(self, inp_img): + stack = [] + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + + ######################################### Encoder level1 ############################################ + ### ART encoder block + for layer in self.ART_encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + + ### Restormer encoder block + out_enc_level1 = rearrange(out_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = self.Restormer_encoder_level1(out_enc_level1) + # stack.append(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + ######################################### Encoder level2 ############################################ + ### ART encoder block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.ART_encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + out_enc_level2 = rearrange(out_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = self.Restormer_encoder_level2(out_enc_level2) + # stack.append(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.ART_encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + out_enc_level3 = rearrange(out_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = self.Restormer_encoder_level3(out_enc_level3) + # stack.append(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 
############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.ART_latent: + latent = layer(latent, [H // 8, W // 8]) + + ### Restormer encoder block + latent = rearrange(latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = self.Restormer_latent(latent) + # stack.append(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + out_dec_level3 = inp_dec_level3 + for layer in self.ART_decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + out_dec_level3 = rearrange(out_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + out_dec_level3 = self.Restormer_decoder_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + out_dec_level2 = inp_dec_level2 + for layer in self.ART_decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + out_dec_level2 = rearrange(out_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + out_dec_level2 = self.Restormer_decoder_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + out_dec_level1 = inp_dec_level1 + for layer in self.ART_decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + ### Restormer decoder block + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_decoder_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.Restormer_refinement(out_dec_level1) + + out_dec_level1 = 
self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ART_Restormer_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ART_Restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..de8f921e81d79cb7af14a75326f0ddaaf3b3f4c3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ART_Restormer_v2.py @@ -0,0 +1,663 @@ +""" +2023/07/24, Xiaohan Xing +U-shape structure, where each block is composed of an ART block and a Restormer block in parallel. +The input features are passed through the ART branch and the Restormer branch separately; the two sets of features are concatenated and then transformed by a conv layer to produce the block's output features. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +from .restormer import TransformerBlock as RestormerBlock + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input of the Attention block:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + # print("output of the Attention block:", x.max(), x.min()) + return x + + +########################################################################## +class ARTBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + 
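# same masking scheme as the dense branch: padded tokens (marked -1 above) are suppressed with NEG_INF before the softmax +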
attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ART_Restormer(nn.Module): + def __init__(self, + inp_channels=3, + out_channels=3, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False + ): + + super(ART_Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.ART_encoder_level1 = nn.ModuleList([ARTBlock(dim=dim, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + self.Restormer_encoder_level1 = nn.Sequential(*[RestormerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.enc_fuse_level1 = nn.Conv2d(int(dim * 
2 ** 1), dim, kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.ART_encoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_encoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.enc_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.ART_encoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_encoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.enc_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.ART_latent = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[3])]) + + self.Restormer_latent = nn.Sequential(*[RestormerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[3])]) + self.enc_fuse_level4 = nn.Conv2d(int(dim * 2 ** 4), int(dim * 2 ** 3), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.ART_decoder_level3 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[2])]) + + self.Restormer_decoder_level3 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[2])]) + self.dec_fuse_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.ART_decoder_level2 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], 
window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[1])]) + + self.Restormer_decoder_level2 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[1])]) + self.dec_fuse_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + self.ART_decoder_level1 = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_art_blocks[0])]) + + self.Restormer_decoder_level1 = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_restormer_blocks[0])]) + self.dec_fuse_level1 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + + self.ART_refinement = nn.ModuleList([ARTBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate) for i in range(num_refinement_blocks)]) + + # self.Restormer_refinement = nn.Sequential(*[RestormerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + # bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + # self.fuse_refinement = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=3, stride=1, padding=1, bias=bias) + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + + # def forward(self, inp_img, aux_img): + # stack = [] + # bs, _, H, W = inp_img.shape + + # fuse_img = torch.cat((inp_img, aux_img), 1) + # inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + + def forward(self, inp_img): + bs, _, H, W = inp_img.shape + inp_enc_level1 = self.patch_embed(inp_img) # b,hw,c + + ######################################### Encoder level1 ############################################ + art_enc_level1 = inp_enc_level1 + restor_enc_level1 = inp_enc_level1 + + ### ART encoder block + for layer in self.ART_encoder_level1: + art_enc_level1 = layer(art_enc_level1, [H, W]) + + ### Restormer encoder block + restor_enc_level1 = rearrange(restor_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_enc_level1 = self.Restormer_encoder_level1(restor_enc_level1) ### (b, c, h, w) + + art_enc_level1 = rearrange(art_enc_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_enc_level1 = torch.cat((art_enc_level1, restor_enc_level1), 1) + out_enc_level1 = self.enc_fuse_level1(out_enc_level1) + out_enc_level1 = rearrange(out_enc_level1, "b c h w -> b (h w) c").contiguous() + # print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + + ######################################### Encoder level2 ############################################ + ### ART encoder 
block + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + art_enc_level2 = inp_enc_level2 + restor_enc_level2 = inp_enc_level2 + + for layer in self.ART_encoder_level2: + art_enc_level2 = layer(art_enc_level2, [H // 2, W // 2]) + + ### Restormer encoder block + restor_enc_level2 = rearrange(restor_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + restor_enc_level2 = self.Restormer_encoder_level2(restor_enc_level2) + + art_enc_level2 = rearrange(art_enc_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_enc_level2 = torch.cat((art_enc_level2, restor_enc_level2), 1) + out_enc_level2 = self.enc_fuse_level2(out_enc_level2) + out_enc_level2 = rearrange(out_enc_level2, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + ######################################### Encoder level3 ############################################ + ### ART encoder block + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + art_enc_level3 = inp_enc_level3 + restor_enc_level3 = inp_enc_level3 + + for layer in self.ART_encoder_level3: + art_enc_level3 = layer(art_enc_level3, [H // 4, W // 4]) + + ### Restormer encoder block + restor_enc_level3 = rearrange(restor_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + restor_enc_level3 = self.Restormer_encoder_level3(restor_enc_level3) + + art_enc_level3 = rearrange(art_enc_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_enc_level3 = torch.cat((art_enc_level3, restor_enc_level3), 1) + out_enc_level3 = self.enc_fuse_level3(out_enc_level3) + out_enc_level3 = rearrange(out_enc_level3, "b c h w -> b (h w) c").contiguous() + + # print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + + ######################################### Encoder level4 ############################################ + ### ART encoder block + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + art_latent = inp_enc_level4 + restor_latent = inp_enc_level4 + + for layer in self.ART_latent: + art_latent = layer(art_latent, [H // 8, W // 8]) + + ### Restormer encoder block + restor_latent = rearrange(restor_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + restor_latent = self.Restormer_latent(restor_latent) + + art_latent = rearrange(art_latent, "b (h w) c -> b c h w", h=H//8, w=W//8).contiguous() + latent = torch.cat((art_latent, restor_latent), 1) + latent = self.enc_fuse_level4(latent) + latent = rearrange(latent, "b c h w -> b (h w) c").contiguous() + + # print("latent feature:", latent.max(), latent.min()) + + + ######################################### Decoder level3 ############################################ + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + ### ART decoder block + art_dec_level3 = inp_dec_level3 + restor_dec_level3 = inp_dec_level3 + + for layer in self.ART_decoder_level3: + art_dec_level3 = layer(art_dec_level3, [H // 4, W // 4]) + + ### Restormer decoder block + restor_dec_level3 = rearrange(restor_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + restor_dec_level3 = 
self.Restormer_decoder_level3(restor_dec_level3) + + art_dec_level3 = rearrange(art_dec_level3, "b (h w) c -> b c h w", h=H//4, w=W//4).contiguous() + out_dec_level3 = torch.cat((art_dec_level3, restor_dec_level3), 1) + out_dec_level3 = self.dec_fuse_level3(out_dec_level3) + out_dec_level3 = rearrange(out_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level2 ############################################ + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + + ### ART decoder block + art_dec_level2 = inp_dec_level2 + restor_dec_level2 = inp_dec_level2 + + for layer in self.ART_decoder_level2: + art_dec_level2 = layer(art_dec_level2, [H // 2, W // 2]) + + ### Restormer decoder block + restor_dec_level2 = rearrange(restor_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + restor_dec_level2 = self.Restormer_decoder_level2(restor_dec_level2) + + art_dec_level2 = rearrange(art_dec_level2, "b (h w) c -> b c h w", h=H//2, w=W//2).contiguous() + out_dec_level2 = torch.cat((art_dec_level2, restor_dec_level2), 1) + out_dec_level2 = self.dec_fuse_level2(out_dec_level2) + out_dec_level2 = rearrange(out_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### Decoder level1 ############################################ + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + + ### ART decoder block + art_dec_level1 = inp_dec_level1 + restor_dec_level1 = inp_dec_level1 + + for layer in self.ART_decoder_level1: + art_dec_level1 = layer(art_dec_level1, [H, W]) + + ### Restormer decoder block + restor_dec_level1 = rearrange(restor_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + restor_dec_level1 = self.Restormer_decoder_level1(restor_dec_level1) + + art_dec_level1 = rearrange(art_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = torch.cat((art_dec_level1, restor_dec_level1), 1) + out_dec_level1 = self.dec_fuse_level1(out_dec_level1) + out_dec_level1 = rearrange(out_dec_level1, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + + + ######################################### final Refinement ############################################ + for layer in self.ART_refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 #, stack + + + +# def build_model(args): +# return ART_Restormer( +# inp_channels=2, +# out_channels=1, +# dim=32, +# num_art_blocks=[2, 4, 4, 6], +# num_restormer_blocks=[2, 2, 2, 2], +# num_refinement_blocks=4, +# heads=[1, 2, 4, 8], +# window_size=[10, 10, 10, 10], mlp_ratio=4., +# qkv_bias=True, qk_scale=None, +# interval=[24, 12, 6, 3], +# ffn_expansion_factor = 2.66, +# LayerNorm_type = 'WithBias', ## Other option 'BiasFree' +# bias=False) + + + +def build_model(args): + return ART_Restormer( + inp_channels=1, + out_channels=1, + dim=32, + num_art_blocks=[2, 4, 4, 6], + num_restormer_blocks=[2, 2, 2, 
2], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + ffn_expansion_factor = 2.66, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + bias=False) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ARTfuse_layer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ARTfuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..78af0b27528b50db897936f6b8b4eee51d566b1c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/ARTfuse_layer.py @@ -0,0 +1,705 @@ +""" +ART layer in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + # print("input to the Attention layer:", x.max(), x.min()) + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + # print("q range:", q.max(), q.min()) + # print("k range:", k.max(), k.min()) + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + # print("attention matrix:", attn.shape, attn) + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + # print("feature before proj-layer:", x.max(), x.min()) + x = self.proj(x) + # print("feature after proj-layer:", x.max(), x.min()) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. 
+ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pre_norm=True): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + # self.conv = nn.Conv2d(dim, dim, kernel_size=1) + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + # self.conv = ConvBlock(self.dim, self.dim, drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.bn_norm = nn.BatchNorm2d(dim) + self.pre_norm = pre_norm + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + # x = self.conv(x) + shortcut = x + if self.pre_norm: + # print("using pre_norm") + x = self.norm1(x) ## 归一化之后特征范围变得非常小 + x = x.view(B, H, W, C) + # print("normalized input range:", x.max(), x.min(), x.mean()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 
1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + # print("range of feature before layer:", x.max(), x.min(), x.mean()) + # x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + # x_bn = self.bn_norm(x.permute(0, 3, 1, 2)) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + # print("[ART layer] input and output feature difference:", torch.mean(torch.abs(shortcut - x))) + # print("range of feature before ART layer:", shortcut.max(), shortcut.min(), shortcut.mean()) + # print("range of feature after ART layer:", x.max(), x.min(), x.mean()) + if self.pre_norm: + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + # print("using post_norm") + x = self.norm1(shortcut + self.drop_path(x)) + x = self.norm2(x + self.drop_path(self.mlp(x))) + # x = shortcut + self.drop_path(x) + # x = x + self.drop_path(self.mlp(x)) + + # print("range of fused feature:", x.max(), x.min()) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + + + +########################################################################## +class Cross_TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + y = F.pad(y, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + y = y.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + y = y.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + y = y.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + y = y.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + + + +########################################################################## +class Cross_TransformerBlock_v2(nn.Module): + r""" ART Transformer Block. + 将两个模态的特征沿着channel方向concat, 一起取window. 之后把取出来的两个模态中同一个window的所有patches合并,送到transformer处理。 + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_x = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_y = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, y, x_size): + """ + x, y: feature maps of two modalities. They are from the same level with same feature size. 
+ """ + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut_x = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shortcut_y = y + y = self.norm1(y) + y = y.view(B, H, W, C) + + # padding + xy = torch.cat((x, y), -1) + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + xy = F.pad(xy, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # print("pad_b and pad_r:", pad_b, pad_r) + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + xy = xy.reshape(B, Hd // G, G, Wd // G, G, 2*C).permute(0, 1, 3, 2, 4, 5).contiguous() + xy = xy.reshape(B * Hd * Wd // G ** 2, G ** 2, 2*C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + xy = xy.reshape(B, Gh, I, Gw, I, 2*C).permute(0, 2, 4, 1, 3, 5).contiguous() + xy = xy.reshape(B * I * I, Gh * Gw, 2*C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # Inside each window, fuse the x and y, then compute self-attention. 
+ x = xy[:, :, :C] + y = xy[:, :, C:] + xy = torch.cat((x, y), 1) + xy = self.attn(xy, Gh, 2*Gw, mask=attn_mask) # nP*B, Gh*2Gw, C + # print("fused xy:", xy.shape) + + # merge embeddings + if self.ds_flag == 0: + x = xy[:, :xy.shape[1]//2, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = xy[:, :xy.shape[1]//2, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + y = xy[:, xy.shape[1]//2:, :].reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + + x = x.reshape(B, Hd, Wd, C) + y = y.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + y = y[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_x + self.drop_path(x) + x = x + self.drop_path(self.mlp_x(self.norm2(x))) + # print("x:", x.shape) + y = shortcut_y + self.drop_path(y) + y = y + self.drop_path(self.mlp_y(self.norm2(y))) + + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DataConsistency.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DataConsistency.py new file mode 100644 index 0000000000000000000000000000000000000000..5f460e9acabcb33e1a3f812eb9acaeb337da446c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DataConsistency.py @@ -0,0 +1,48 @@ +""" +Created: DataConsistency @ Xiyang Cai, 2023/09/09 + +Data consistency layer for k-space signal. 
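+
+The operator enforces the measured samples on a network prediction:
+
+    out = (1 - mask) * k + mask * k0
+
+so the originally acquired values k0 are restored wherever the sampling mask is nonzero,
+and the prediction k is kept only at unsampled locations.
+
+A minimal usage sketch (tensor names are illustrative):
+
+    dc = DataConsistency()
+    k_dc, tiled_mask = dc(k_pred, k0, mask)  # k_pred, k0: (n, nx, ny, 2); mask: (n, 1, len, 1)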
+ +Ref: DataConsistency in DuDoRNet (https://github.com/bbbbbbzhou/DuDoRNet) + +""" + +import torch +from torch import nn +from einops import repeat + + +def data_consistency(k, k0, mask): + """ + k - input in k-space + k0 - initially sampled elements in k-space + mask - corresponding nonzero location + """ + out = (1 - mask) * k + mask * k0 + return out + + +class DataConsistency(nn.Module): + """ + Create data consistency operator + """ + + def __init__(self): + super(DataConsistency, self).__init__() + + def forward(self, k, k0, mask): + """ + k - input in frequency domain, of shape (n, nx, ny, 2) + k0 - initially sampled elements in k-space + mask - corresponding nonzero location (n, 1, len, 1) + """ + + if k.dim() != 4: # input is 2D + raise ValueError("error in data consistency layer!") + + # mask = repeat(mask.squeeze(1, 3), 'b x -> b x y c', y=k.shape[1], c=2) + mask = torch.tile(mask, (1, mask.shape[2], 1, k.shape[-1])) ### [n, 320, 320, 2] + # print("k and k0 shape:", k.shape, k0.shape) + out = data_consistency(k, k0, mask) + + return out, mask diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DuDo_mUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DuDo_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a56069a27f0cdd392482b41dc4e3b79252295487 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DuDo_mUnet.py @@ -0,0 +1,359 @@ +""" +Xiaohan Xing, 2023/10/25 +传入两个模态的kspace data, 经过kspace network进行重建。 +然后把重建之后的kspace data变换回图像,然后送到image domain的网络进行图像重建。 +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + +from .Kspace_mUnet import kspace_mmUnet + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = kspace_ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = kspace_ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = kspace_ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = kspace_ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. 
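+        # self.freq_filter collapses the three multi-scale feature maps (3 * chans channels) into
+        # a single-channel map via two 1x1 convolutions (ReLU in between) and a final sigmoid,
+        # i.e. a (0, 1) gate per k-space location. Re-weighting every frequency component this way
+        # is, by the convolution theorem, equivalent to convolving with an image-sized kernel in
+        # the image domain, which is what the comment above refers to.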
+ spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class DuDo_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 3, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + # self.kspace_net = mmConvKSpace() + self.kspace_net = kspace_mmUnet(args) + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
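+
+        Note: this dual-domain variant is not called with `image` directly. Its inputs are
+        (args, kspace, ref_kspace, mask, aux_image, t1_max, t2_max); when args.domain == "dual"
+        the k-space sub-network first refines `kspace`, and the magnitude image is recovered via
+        ifft2c before the image-domain U-Net. The call returns
+        (output, image, recon_kspace, masked_kspace, data_stats).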
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + stack = [] + feature_stack = [] + # output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ output = torch.cat((image, LR_image, aux_image), 1) #.detach() + + # print("LR image range:", LR_image.max(), LR_image.min()) + # print("kspace_recon image range:", image.max(), image.min()) + # print("aux_image range:", aux_image.max(), aux_image.min()) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output, image, recon_kspace, masked_kspace, data_stats + + + +class kspace_ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return DuDo_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DuDo_mUnet_ARTfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DuDo_mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6df67b63140ca51bd7b8664151ab4b1b7a6305d5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DuDo_mUnet_ARTfusion.py @@ -0,0 +1,369 @@ +""" +Xiaohan Xing, 2023/11/08 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + ARTfusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", 
output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
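+
+        Note: the actual inputs are (args, kspace, ref_kspace, mask, aux_image, t1_max, t2_max).
+        Each modality is reconstructed by its own encoder/decoder branch, with ART-based fusion
+        at every encoder level, and the call returns
+        (recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats).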
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + output1, output2 = aux_image.detach(), LR_image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, image, recon_kspace, masked_kspace, data_stats + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DuDo_mUnet_CatFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DuDo_mUnet_CatFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c2b86b6e2031b1af07ab7cb58efc182b7a2673 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/DuDo_mUnet_CatFusion.py @@ -0,0 +1,338 @@ +""" +Xiaohan Xing, 2023/11/10 +在image domain network的前面加上kspace interpolation. +kspace用Unet做interpolation. +image domain用mUnet + multi layer concat fusion模型。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import TransformerBlock + +from .Kspace_mUnet import kspace_mmUnet +from .DataConsistency import DataConsistency +from dataloaders.BRATS_kspace_dataloader import ifft2c + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output 
= torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_CATfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.kspace_net = kspace_mmUnet(args) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def complex_abs(self, data): + """ + data: [B, 2, H, W] + """ + if data.size(1) == 2: + return (data ** 2).sum(dim=1).sqrt() + elif data.size(-1) == 2: + return (data ** 2).sum(dim=-1).sqrt() + + + def normalize(self, data, mean, stddev, eps=0.0): + """ + Normalize the given tensor. + Applies the formula (data - mean) / (stddev + eps). + """ + return (data - mean) / (stddev + eps) + + + def forward(self, args, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, \ + aux_image: torch.Tensor, t2_gt: torch.Tensor, t1_max: torch.Tensor, t2_max: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
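+
+        Note:
+            The Args/Returns above are inherited wording for a single image tensor. This variant
+            additionally takes the ground-truth T2 image (t2_gt) so that it can be normalized with
+            the same statistics, and it returns
+            (recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image, recon_kspace,
+            masked_kspace, data_stats).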
+ """ + if args.domain == "dual": + recon_kspace, masked_kspace = self.kspace_net(kspace, ref_kspace, mask) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("recon kspace range:", recon_kspace.max(), recon_kspace.min()) + image = self.complex_abs(ifft2c(recon_kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + + elif args.domain == "image": + image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + recon_kspace, masked_kspace = None, None + + image = image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + aux_image = aux_image * t1_max.to(torch.float32).view(-1, 1, 1, 1) + t2_image = t2_gt * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + LR_image = self.complex_abs(ifft2c(kspace.permute(0, 2, 3, 1))).unsqueeze(1) #.detach() + LR_image = LR_image * t2_max.to(torch.float32).view(-1, 1, 1, 1) + + # # print("aux_image range:", aux_image.max(), aux_image.min()) + # print("LR_image range:", LR_image.max(), LR_image.min()) + # print("kspace recon_img range:", image.max(), image.min()) + + ### normalize the input data with (x-mean)/std, and save the mean and std for recovery. + t1_mean = aux_image.view(aux_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t1_std = aux_image.view(aux_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + t2_mean = image.view(image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + t2_std = image.view(image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # t2_LR_mean = LR_image.view(LR_image.shape[0], -1).mean(dim=1).view(-1, 1, 1, 1).detach() + # t2_LR_std = LR_image.view(LR_image.shape[0], -1).std(dim=1).view(-1, 1, 1, 1).detach() + + # print("t2_kspace_recon mean and std:", t2_mean.view(-1), t2_std.view(-1)) + # print("t2_LR mean and std:", t2_LR_mean.view(-1), t2_LR_std.view(-1)) + + aux_image = self.normalize(aux_image, t1_mean, t1_std, eps=1e-11) + image = self.normalize(image, t2_mean, t2_std, eps=1e-11) + LR_image = self.normalize(LR_image, t2_mean, t2_std, eps=1e-11) + t2_gt_image = self.normalize(t2_image, t2_mean, t2_std, eps=1e-11) + + aux_image = torch.clamp(aux_image, -6, 6) + image = torch.clamp(image, -6, 6) + LR_image = torch.clamp(LR_image, -6, 6) + t2_gt_image = torch.clamp(t2_gt_image, -6, 6) + + + # print("normalized t1 range:", aux_image.max(), aux_image.min()) + # print("normalized kspace_recon_t2 range:", image.max(), image.min()) + # print("normalized LR_t2 range:", LR_image.max(), LR_image.min()) + + data_stats = {"t1_mean": t1_mean, "t1_std": t1_std, "t2_mean": t2_mean, "t2_std": t2_std} + + # output1, output2 = aux_image.detach(), torch.cat((image, LR_image), 1).detach() + # output1, output2 = aux_image.detach(), LR_image.detach() + output1, output2 = aux_image.detach(), image.detach() + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + l = 0 + bs, _, H, W = aux_image.shape + + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + recon_t1 = self.decoder1(output1, stack1) + recon_t2 = self.decoder2(output2, stack2) + + return recon_t1, recon_t2, aux_image, t2_gt_image, image, LR_image, recon_kspace, masked_kspace, data_stats + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_CATfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Kspace_ConvNet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Kspace_ConvNet.py new file mode 100644 index 0000000000000000000000000000000000000000..918cbe598a62653394c8c4eaf99cd10a45c064ca --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Kspace_ConvNet.py @@ -0,0 +1,140 @@ +""" +2023/10/05 +Xiaohan Xing +Build a simple network with several convolution layers. 
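+The network runs three parallel ConvBlocks with 3x3, 5x5 and 7x7 kernels, gates the
+concatenated features with a learned element-wise sigmoid map (freq_filter), and applies the
+DataConsistency layer against the input k-space and sampling mask.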
+Input: concat (undersampled kspace of FS-PD, fully-sampled kspace of PD modality) +Output: interpolated kspace of FS-PD +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class mmConvKSpace(nn.Module): + """ + Interpolate the kspace data of the target modality with several conv layers. + """ + def __init__( + self, args, + chans: int = 32, + drop_prob: float = 0.0): + + super().__init__() + + self.in_chans = 4 + self.out_chans = 2 + self.chans = chans + self.drop_prob = drop_prob + + self.small_conv = ConvBlock(self.in_chans, self.chans, 3, self.drop_prob) + self.mid_conv = ConvBlock(self.in_chans, self.chans, 5, self.drop_prob) + self.large_conv = ConvBlock(self.in_chans, self.chans, 7, self.drop_prob) + + self.freq_filter = nn.Sequential(nn.Conv2d(3 * self.chans, self.chans, kernel_size=1, padding=0, bias=False), + nn.ReLU(), + nn.Conv2d(self.chans, 1, kernel_size=1, padding=0, bias=False), + nn.Sigmoid()) + + self.final_conv = ConvBlock(3 * self.chans, self.out_chans, 3, self.drop_prob) + + # self.conv_blocks = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # self.conv_blocks.append(ConvBlock(self.chans, self.chans * 2, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans * 2, self.chans, drop_prob)) + # self.conv_blocks.append(ConvBlock(self.chans, self.out_chans, drop_prob)) + + self.dcs = DataConsistency() + + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + kspace: Input 4D tensor of shape `(N, H, W, 4)`. + mask: Down-sample mask `(N, 1, len, 1)` + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + k_in = torch.cat((kspace, ref_kspace), 1) + # k_in = k_in.permute(0, 3, 1, 2) + output = k_in + # print("image shape:", image.shape) + + output1 = self.small_conv(output) + output2 = self.mid_conv(output) + output3 = self.large_conv(output) + # print("output of the three conv blocks:", output1.shape, output2.shape, output3.shape) + # print("input kspace range:", kspace.max(), kspace.min()) + # print("conv1_output range:", output1.max(), output1.min()) + # print("conv2_output range:", output2.max(), output2.min()) + # print("conv3_output range:", output3.max(), output3.min()) + + output = torch.cat((output1, output2, output3), 1) + # print("output range before freq_filter:", output.max(), output.min()) + + ### Element-wise multiplication in the Frequency domain = full image size conv in the image domain. + spatial_weights = self.freq_filter(output) + # print("spatial_weights range:", spatial_weights.max(), spatial_weights.min()) + output = spatial_weights * output + # print("output range after freq_filter:", output.max(), output.min()) + + output = self.final_conv(output) + + output = output.permute(0, 2, 3, 1) + # print("output before DC layer:", output.shape) + # print("output range before DC layer:", output.max(), output.min()) + # output = output + kspace ## residual connection + + output, mask = self.dcs(output, kspace.permute(0, 2, 3, 1), mask) + mask_output = mask * output + + # mask_output = output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
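+    Note: in this k-space variant the LeakyReLU activations are commented out below, so each
+    stage is effectively Conv2d -> InstanceNorm2d -> Dropout2d.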
+ """ + + def __init__(self, in_chans: int, out_chans: int, kernel_size: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + kernel_size: size of the convolution kernel. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size, padding=(kernel_size-1)//2, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +def build_model(args): + return mmConvKSpace(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Kspace_mUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Kspace_mUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2db9a357798be4abd4cfbd44e9207555770b0456 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Kspace_mUnet.py @@ -0,0 +1,202 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # print("output:", output.shape) + # print("kspace:", kspace.shape) + # print("mask:", mask.shape) + mask_output = mask * output + + return output.permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Kspace_mUnet_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Kspace_mUnet_new.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ccfd28dc198ffeff2259a5fbed45dabdc76819 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Kspace_mUnet_new.py @@ -0,0 +1,200 @@ +""" +Xiaohan Xing, 2023/10/24 +For each modality, the real and imaginary parts of the kspace data are concatenated along the channel axis, +so the kspace data are in the shape of [bs, 2, 240, 240]. +We concat the kspace data from different modalities along the channel axis and input it into the UNet. +Input: [bs, 4, 240, 240], +Output: [bs, 2, 240, 240] +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F +from .DataConsistency import DataConsistency + + +class kspace_mmUnet(nn.Module): + + def __init__( + self, args, + input_dim: int = 4, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + self.dcs = DataConsistency() + + def forward(self, kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((kspace, ref_kspace), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + # print("using Unet without Relu layers") + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + + # output, mask = self.dcs(output.permute(0, 2, 3, 1), kspace.permute(0, 2, 3, 1), mask) + # # print("output:", output.shape) + # # print("kspace:", kspace.shape) + # # print("mask:", mask.shape) + # mask_output = mask * output + + return output # .permute(0, 3, 1, 2), mask_output.permute(0, 3, 1, 2) + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return kspace_mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/MINet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeeb063b12be1dc594e8e076e6a59e454fcfa0b --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/MINet.py @@ -0,0 +1,356 @@ +""" +Compare with the method "Multi-contrast MRI Super-Resolution via a Multi-stage Integration Network". +The code is from https://github.com/chunmeifeng/MINet/edit/main/fastmri/models/MINet.py. 
+""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F + + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): 
+ def __init__(self,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 1 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + # print("modules_tail:", modules_tail) + # self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Sigmoid()) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + + + +class MINet(nn.Module): + def __init__(self, args, n_resgroups=3, n_resblocks=3, n_feats=64): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + # print("x1:", x1.shape, "x2:", x2.shape) + + ### 源代码中是super-resolution任务,所以self.tail对图像进行unsampling. 
我们做reconstruction把这步去掉 + x2 = self.tail(x2) + + + resT1 = x1 + resT2 = x2 + + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + # print("t1s:", [item.shape for item in t1s]) + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + # print("resT1 range:", resT1.max(), resT1.min()) + # print("resT2 range:", resT2.max(), resT2.min()) + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1, x2 #x1=pd x2=pdfs + + +def build_model(args): + return MINet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/MINet_common.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/MINet_common.py new file mode 100644 index 0000000000000000000000000000000000000000..631f3036db4edf8eab00d70c66d2058a3b65cd59 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/MINet_common.py @@ -0,0 +1,90 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def default_conv(in_channels, out_channels, kernel_size, bias=True): + return nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + + return res + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
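+            # For power-of-two scales, up-sample by repeated x2 stages: each stage expands the
+            # channels to 4*n_feats with a conv and then rearranges them spatially with
+            # nn.PixelShuffle(2); for scale=1 the loop below runs zero times (no up-sampling).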
+ for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + # print("upsampler:", m) + + super(Upsampler, self).__init__(*m) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/SANet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/SANet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfac3590e248c9532b002885c6c0b7a55ab1400a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/SANet.py @@ -0,0 +1,392 @@ +""" +This script contains the codes of "Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution (TNNLS 2023)" +https://github.com/chunmeifeng/SANet/blob/main/fastmri/models/SANet.py +""" + +from .MINet_common import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import os + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = 
x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,scale,n_resgroups,n_resblocks,n_feats,conv=default_conv): + super(SR_Branch, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups #10 + self.n_resblocks = n_resblocks #20 + self.n_feats = n_feats #64 + kernel_size = 3 + reduction = 16 #16 + # scale = args.scale[0] + + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(2*n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + # self.final = nn.Sequential(nn.Conv2d(n_feats, n_colors, 3, 1, 1), nn.Tanh()) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler 
to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' + .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class Seatt(nn.Module): + def __init__(self, in_c): + super(Seatt, self).__init__() + self.reduce = nn.Conv2d(in_c * 2, 32, 1) + self.ff_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.bf_conv = nn.Sequential( + nn.Conv2d(32, 32, 3, 1, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.rgbd_pred_layer = Pred_Layer(32 * 2) + self.convq = nn.Conv2d(32, 32, 3, 1, 1) + + def forward(self, rgb_feat, dep_feat, pred): + feat = torch.cat((rgb_feat, dep_feat), 1) + + # vis_attention_map_rgb_feat(rgb_feat[0][0]) + # vis_attention_map_dep_feat(dep_feat[0][0]) + # vis_attention_map_feat(feat[0][0]) + + + feat = self.reduce(feat) + [_, _, H, W] = feat.size() + pred = torch.sigmoid(pred) + + ni_pred = 1 - pred + + ff_feat = self.ff_conv(feat * pred) + bf_feat = self.bf_conv(feat * (1 - pred)) + new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) + + return new_pred + +class SANet(nn.Module): + def __init__(self, args, scale=1, n_resgroups=3, n_resblocks=3, n_feats=64): + super(SANet, self).__init__() + self.scale = scale + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + + cs = [64,64,64,64] + self.Seatts = nn.ModuleList([Seatt(c) for c in cs]) + + self.net1 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + scale = self.scale, + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + print("nlayer:",nlayer) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=3, padding=1) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2, Seatt, map_conv in zip(self.net1.body._modules.items(), self.net2.body._modules.items(), self.Seatts, self.map_convs): + + name1, midlayer1 = 
m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + pred = map_conv(resT1) + res = Seatt(resT1,resT2,pred) + + resT2 = res+resT2 + + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + + return x1,x2 + + +def build_model(args): + return SANet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/SwinFuse_layer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/SwinFuse_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3dca8ba793d92961bc3f1d987e211fe542b303 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/SwinFuse_layer.py @@ -0,0 +1,1310 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # print("x shape:", x.shape) + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + # print("num heads:", self.num_heads) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + 
flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask 
is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + # print("x_windows:", x_windows.shape) + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + + # # 不使用残差连接 + # x = self.residual_group_A(x, x_size) + # y = self.residual_group_B(y, x_size) + + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion_layer(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
+ Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Fusion_num_heads=[6, 6], + window_size=7, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, img_range=1., resi_connection='1conv', + **kwargs): + super(SwinFusion_layer, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + 
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + + ### fusion layer. + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + self.norm_Fusion_B = norm_layer(self.num_features) + + # print("fusion layers:", len(self.layers_Fusion)) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + + def forward_features_Fusion(self, x, y): + + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + # x = torch.cat([x, y], 1) + # ## Downsample the feature in the channel dimension + # x = self.lrelu(self.conv_after_body_Fusion(x)) + # print("x range after fusion:", x.max(), x.min()) + # print("y 
range after fusion:", y.max(), y.min()) + # print("x:", x.shape) + + return x, y + + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = self.check_image_size(y) + + print("modality1 input:", x.max(), x.min()) + print("modality2 input:", y.max(), y.min()) + + # multi-modal feature fusion. + x, y = self.forward_features_Fusion(x, y) ### 返回两个模态融合之后的features. + print("modality1 output:", x.max(), x.min()) + print("modality2 output:", y.max(), y.min()) + + return x[:, :, :H, :W], y[:, :, :H, :W] + + + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + + +if __name__ == '__main__': + width, height = 120, 120 + swinfuse = SwinFusion_layer(img_size=(width, height), window_size=8, depths=[6, 6, 6], embed_dim=60, num_heads=[6, 6, 6]) + # print("SwinFusion layer:", swinfuse) + + A = torch.randn((4, 60, width, height)) + B = torch.randn((4, 60, width, height)) + + x = swinfuse(A, B) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/SwinFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/SwinFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3e72c2a3e8859423f3544043e8b9feaec002d0e2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/SwinFusion.py @@ -0,0 +1,1468 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# code from https://github.com/Linfeng-Tang/SwinFusion/tree/master +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + # print("qkv shape:", self.qkv(x).shape) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # print("qkv after reshape:", qkv.shape) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class Cross_WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. 
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim*2 , bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C), which maps query + y: input features with shape of (num_windows*B, N, C), which maps key and value + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = q[0], kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class Cross_SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1_A = norm_layer(dim) + self.norm1_B = norm_layer(dim) + self.attn_A = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.attn_B = Cross_WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path_A = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path_B = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2_A = norm_layer(dim) + self.norm2_B = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_A = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp_B = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, y, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut_A = x + shortcut_B = y + x = self.norm1_A(x) + y = self.norm1_B(y) + x = x.view(B, H, W, C) + y = y.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_y = torch.roll(y, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_y = y + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + y_windows = window_partition(shifted_y, self.window_size) # nW*B, window_size, window_size, C + y_windows = y_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows_A = self.attn_A(x_windows, y_windows, mask=self.calculate_mask(x_size).to(x.device)) + attn_windows_B = self.attn_B(y_windows, x_windows, mask=self.calculate_mask(x_size).to(y.device)) + + # merge windows + attn_windows_A = attn_windows_A.view(-1, self.window_size, self.window_size, C) + attn_windows_B = attn_windows_B.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows_A, self.window_size, H, W) # B H' W' C + shifted_y = window_reverse(attn_windows_B, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + y = torch.roll(shifted_y, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + y = shifted_y + x = x.view(B, H * W, C) + y = y.view(B, H * W, C) + + # FFN + x = shortcut_A + 
self.drop_path_A(x) + x = x + self.drop_path_A(self.mlp_A(self.norm2_A(x))) + + y = shortcut_B + self.drop_path_B(y) + y = y + self.drop_path_B(self.mlp_B(self.norm2_B(y))) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
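+
+        Example (an illustrative sketch, not part of the original code; it assumes
+        ``dim`` is divisible by ``num_heads`` and that the resolution is a multiple
+        of ``window_size``):
+
+            layer = BasicLayer(dim=60, input_resolution=(48, 48), depth=2,
+                               num_heads=6, window_size=8)
+            tokens = torch.randn(1, 48 * 48, 60)   # (B, H*W, C) token layout
+            out = layer(tokens, (48, 48))          # shape is preserved: (1, 2304, 60)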
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class Cross_BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + Cross_SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + # 以x为主, y只是辅助x的特征提取 + def forward(self, x, y, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x, y = checkpoint.checkpoint(blk, x, y, x_size) + else: + x, y = blk(x, y, x_size) + if self.downsample is not None: + x = self.downsample(x) + y = self.downsample(y) + # print("Cross_BasicLayer:", type(x), type(y)) + return x, y + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + # if resi_connection == '1conv': + # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, x_size): + # return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + # return self.residual_group(x, x_size) + x + return self.residual_group(x, x_size) + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class CRSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(CRSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = Cross_BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_A = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.residual_group_B = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + # if resi_connection == '1conv': + # self.conv_A = nn.Conv2d(dim, dim, 3, 1, 1) + # self.conv_B = nn.Conv2d(dim, dim, 3, 1, 1) + # elif resi_connection == '3conv': + # # to save parameters and memory + # self.conv_A = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + # self.conv_B = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + # nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + # self.patch_embed = PatchEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + # self.patch_unembed = PatchUnEmbed( + # img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + # norm_layer=None) + + def forward(self, x, y, x_size): + ## Intra-Modal Fusion + # x = self.residual_group_A(x, x_size) + x + # y = self.residual_group_B(y, x_size) + y + # 不使用残差连接 + # print("input x shape:", x.shape) + x = self.residual_group_A(x, x_size) + y = self.residual_group_B(y, x_size) + ## Inter-Modal Fusion + x1 = x + y1 = y + x, y = self.residual_group(x1, y1, x_size) + # x = self.patch_embed(self.conv_A(self.patch_unembed(x, x_size))) + x1 + # y = self.patch_embed(self.conv_B(self.patch_unembed(y, x_size))) + y1 + # 不使用残差连接 + # x = x + x1 + # y = y + y1 + return x, y + + def flops(self): + flops = 0 + flops += self.residual_group_A.flops() + flops += self.residual_group_B.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + + return flops + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinFusion(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. 
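+    In this fusion variant, two inputs (single-channel by default) are embedded by
+    separate shallow convolutions, processed by modality-specific RSTB stacks (Ex_A / Ex_B),
+    fused by cross-attention CRSTB stacks, and decoded by a reconstruction RSTB stack,
+    followed by a three-convolution head and a Tanh in the default configuration used below.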
+ + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, Ex_depths=[4], Fusion_depths=[2, 2], Re_depths=[4], + Ex_num_heads=[6], Fusion_num_heads=[6, 6], Re_num_heads=[6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinFusion, self).__init__() + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + embed_dim_temp = int(embed_dim / 2) + print('in_chans: ', in_chans) + if in_chans == 3 or in_chans == 6: + rgb_mean = (0.4488, 0.4371, 0.4040) + rgbrgb_mean = (0.4488, 0.4371, 0.4040, 0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_in = torch.Tensor(rgbrgb_mean).view(1, 6, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + ####修改shallow feature extraction 网络, 修改为2个3x3的卷积#### + self.conv_first1_A = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first1_B = nn.Conv2d(in_chans, embed_dim_temp, 3, 1, 1) + self.conv_first2_A = nn.Conv2d(embed_dim_temp, embed_dim, 3, 1, 1) + self.conv_first2_B = nn.Conv2d(embed_dim_temp, embed_dim_temp, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.Ex_num_layers = len(Ex_depths) + self.Fusion_num_layers = len(Fusion_depths) + self.Re_num_layers = len(Re_depths) + + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = 
patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + self.softmax = nn.Softmax(dim=0) + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr_Ex = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Ex_depths))] # stochastic depth decay rule + dpr_Fusion = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Fusion_depths))] # stochastic depth decay rule + dpr_Re = [x.item() for x in torch.linspace(0, drop_path_rate, sum(Re_depths))] # stochastic depth decay rule + # build Residual Swin Transformer blocks (RSTB) + self.layers_Ex_A = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_A.append(layer) + self.norm_Ex_A = norm_layer(self.num_features) + + self.layers_Ex_B = nn.ModuleList() + for i_layer in range(self.Ex_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Ex_depths[i_layer], + num_heads=Ex_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Ex[sum(Ex_depths[:i_layer]):sum(Ex_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Ex_B.append(layer) + self.norm_Ex_B = norm_layer(self.num_features) + + self.layers_Fusion = nn.ModuleList() + for i_layer in range(self.Fusion_num_layers): + layer = CRSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Fusion_depths[i_layer], + num_heads=Fusion_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Fusion[sum(Fusion_depths[:i_layer]):sum(Fusion_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Fusion.append(layer) + self.norm_Fusion_A = norm_layer(self.num_features) + 
self.norm_Fusion_B = norm_layer(self.num_features) + + self.layers_Re = nn.ModuleList() + for i_layer in range(self.Re_num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=Re_depths[i_layer], + num_heads=Re_num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr_Re[sum(Re_depths[:i_layer]):sum(Re_depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers_Re.append(layer) + self.norm_Re = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + # self.conv_after_body_Ex_A = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Ex_B = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_after_body_Fusion = nn.Conv2d(2 * embed_dim, embed_dim, 3, 1, 1) + # self.conv_after_body_Re = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last1 = nn.Conv2d(embed_dim, embed_dim_temp, 3, 1, 1) + self.conv_last2 = nn.Conv2d(embed_dim_temp, int(embed_dim_temp/2), 3, 1, 1) + self.conv_last3 = nn.Conv2d(int(embed_dim_temp/2), num_out_ch, 3, 1, 1) + + self.tanh = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features_Ex_A(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_A: + x = layer(x, x_size) + + x = self.norm_Ex_A(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_Ex_B(self, x): + x = self.lrelu(self.conv_first1_A(x)) + x = self.lrelu(self.conv_first2_A(x)) + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Ex_B: + x = layer(x, x_size) + + x = self.norm_Ex_B(x) # B L C + x = self.patch_unembed(x, x_size) + return x + + def forward_features_Fusion(self, x, y): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + y = self.patch_embed(y) + # print("x token:", x.shape, "y token:", y.shape) + if self.ape: + x = x + self.absolute_pos_embed + y = y + self.absolute_pos_embed + x = self.pos_drop(x) + y = self.pos_drop(y) + + for layer in self.layers_Fusion: + x, y = layer(x, y, x_size) + # y = layer(y, x, x_size) + + + x = self.norm_Fusion_A(x) # B L C + x = self.patch_unembed(x, x_size) + + y = self.norm_Fusion_B(y) # B L C + y = self.patch_unembed(y, x_size) + x = torch.cat([x, y], 1) + ## Downsample the feature in the channel dimension + x = self.lrelu(self.conv_after_body_Fusion(x)) + + return x + + def forward_features_Re(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_Re: + x = layer(x, x_size) + + x = self.norm_Re(x) # B L C + x = self.patch_unembed(x, x_size) + # Convolution + x = self.lrelu(self.conv_last1(x)) + x = self.lrelu(self.conv_last2(x)) + x = self.conv_last3(x) + return x + + def forward(self, A, B): + # print("Initializing the model") + x = A + y = B + H, W = x.shape[2:] + x = self.check_image_size(x) + y = 
self.check_image_size(y) + + # self.mean_A = self.mean.type_as(x) + # self.mean_B = self.mean.type_as(y) + # self.mean = (self.mean_A + self.mean_B) / 2 + + # x = (x - self.mean_A) * self.img_range + # y = (y - self.mean_B) * self.img_range + + # Feedforward + x = self.forward_features_Ex_A(x) + y = self.forward_features_Ex_B(y) + # print("x before fusion:", x.shape) + # print("y before fusion:", y.shape) + x = self.forward_features_Fusion(x, y) + x = self.forward_features_Re(x) + x = self.tanh(x) + + # x = x / self.img_range + self.mean + # print("H:", H, "upscale:", self.upscale) + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers_Ex_A): + flops += layer.flops() + for i, layer in enumerate(self.layers_Ex_B): + flops += layer.flops() + for i, layer in enumerate(self.layers_Fusion): + flops += layer.flops() + for i, layer in enumerate(self.layers_Re): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='convs') + + +if __name__ == '__main__': + model = SwinFusion(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='convs') + # print(model) + + A = torch.randn((1, 1, 240, 240)) + B = torch.randn((1, 1, 240, 240)) + + x = model(A, B) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/TransFuse.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/TransFuse.py new file mode 100644 index 0000000000000000000000000000000000000000..56ca3d2b24f382603c876e4f6909363a9794ffa3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/TransFuse.py @@ -0,0 +1,155 @@ +import math +from collections import deque + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torchvision import models + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. 
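+    Note: despite the word "masked" above, no causal or padding mask is applied in this
+    implementation; attention is computed over all tokens.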
+ """ + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(n_embd, n_embd) + self.query = nn.Linear(n_embd, n_embd) + self.value = nn.Linear(n_embd, n_embd) + # regularization + self.attn_drop = nn.Dropout(attn_pdrop) + self.resid_drop = nn.Dropout(resid_pdrop) + # output projection + self.proj = nn.Linear(n_embd, n_embd) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop): + super().__init__() + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.mlp = nn.Sequential( + nn.Linear(n_embd, block_exp * n_embd), + nn.ReLU(True), # changed from GELU + nn.Linear(block_exp * n_embd, n_embd), + nn.Dropout(resid_pdrop), + ) + + def forward(self, x): + B, T, C = x.size() + + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + + return x + + +class TransFuse_layer(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, n_embd, n_head, block_exp, n_layer, + num_anchors, seq_len=1, + embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1): + super().__init__() + self.n_embd = n_embd + self.seq_len = seq_len + self.vert_anchors = num_anchors + self.horz_anchors = num_anchors + + # positional embedding parameter (learnable), image + lidar + self.pos_emb = nn.Parameter(torch.zeros(1, 2 * seq_len * self.vert_anchors * self.horz_anchors, n_embd)) + + self.drop = nn.Dropout(embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(n_embd, n_head, + block_exp, attn_pdrop, resid_pdrop) + for layer in range(n_layer)]) + + # decoder head + self.ln_f = nn.LayerNorm(n_embd) + + self.block_size = seq_len + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + + def forward(self, m1, m2): + """ + Args: + m1 (tensor): B*seq_len, C, H, W + m2 (tensor): B*seq_len, C, H, W + """ + + bz = m2.shape[0] // self.seq_len + h, w = m2.shape[2:4] + + # forward the image model for token embeddings + m1 = m1.view(bz, self.seq_len, -1, h, w) + m2 = m2.view(bz, self.seq_len, -1, h, w) + + # pad token embeddings along number of tokens dimension + token_embeddings = torch.cat([m1, m2], 
dim=1).permute(0,1,3,4,2).contiguous()
+        token_embeddings = token_embeddings.view(bz, -1, self.n_embd)  # (B, an * T, C)
+
+        # add (learnable) positional embedding for all tokens
+        x = self.drop(self.pos_emb + token_embeddings)  # (B, an * T, C)
+        x = self.blocks(x)  # (B, an * T, C)
+        x = self.ln_f(x)  # (B, an * T, C)
+        x = x.view(bz, 2 * self.seq_len, self.vert_anchors, self.horz_anchors, self.n_embd)
+        x = x.permute(0,1,4,2,3).contiguous()  # same as token_embeddings
+
+        m1_out = x[:, :self.seq_len, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w)
+        m2_out = x[:, self.seq_len:, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w)
+        print("modality1 output:", m1_out.max(), m1_out.min())
+        print("modality2 output:", m2_out.max(), m2_out.min())
+
+        return m1_out, m2_out
+
+
+if __name__ == "__main__":
+
+    feature1 = torch.randn((4, 512, 20, 20))
+    feature2 = torch.randn((4, 512, 20, 20))
+    model = TransFuse_layer(n_embd=512, n_head=4, block_exp=4, n_layer=8, num_anchors=20, seq_len=1)
+    print("TransFuse_layer:", model)
+    feat1, feat2 = model(feature1, feature2)
+    print(feat1.shape, feat2.shape)
diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Unet_ART.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Unet_ART.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9b0f834270a921ae4eb428c92b3e782fbed19b3
--- /dev/null
+++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/Unet_ART.py
@@ -0,0 +1,243 @@
+"""
+2023/07/06: attach an ART transformer block after the features of every U-Net level to extract long-range dependent features.
+"""
+"""
+Copyright (c) Facebook, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .ARTfuse_layer import TransformerBlock
+
+
+class Unet_ART(nn.Module):
+    """
+    PyTorch implementation of a U-Net model.
+
+    O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
+    for biomedical image segmentation. In International Conference on Medical
+    image computing and computer-assisted intervention, pages 234–241.
+    Springer, 2015.
+    """
+
+    def __init__(
+        self, args,
+        input_dim: int = 1,
+        output_dim: int = 1,
+        chans: int = 32,
+        num_pool_layers: int = 4,
+        drop_prob: float = 0.0,
+    ):
+    # def __init__(self, args):
+        """
+        Args:
+            in_chans: Number of channels in the input to the U-Net model.
+            out_chans: Number of channels in the output to the U-Net model.
+            chans: Number of output channels of the first convolution layer.
+            num_pool_layers: Number of down-sampling and up-sampling layers.
+            drop_prob: Dropout probability.
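+
+        Example (an illustrative sketch, not part of the original code; it assumes the
+        240x240 input size hard-coded in ``self.img_size`` and that the imported ART
+        ``TransformerBlock`` preserves the token count):
+
+            model = Unet_ART(args=None)
+            image = torch.randn(1, 1, 240, 240)
+            recon = model(image)   # expected shape (1, 1, 240, 240)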
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + output = conv_layer(output) + + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..334f1821a9b59e692641f70cdc61e7655b1a9c89 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/__init__.py @@ -0,0 +1,114 @@ +from .unet import build_model as UNET +from .mmunet import build_model as MUNET +from .humus_net import build_model as HUMUS +from .MINet import build_model as MINET +from .SANet import build_model as SANET +from .mtrans_net import build_model as MTRANS +from .unimodal_transformer import build_model as TRANS +from .cnn_transformer_new import build_model as CNN_TRANS +from .cnn import build_model as CNN +from .unet_transformer import build_model as UNET_TRANSFORMER +from .swin_transformer import build_model as SWIN_TRANS +from .restormer import build_model as RESTORMER + + +from .cnn_late_fusion import build_model as MULTI_CNN +from .mmunet_late_fusion import build_model as MMUNET_LATE +from .mmunet_early_fusion import build_model as MMUNET_EARLY +from .trans_unet.trans_unet import build_model as TRANSUNET +from .SwinFusion import build_model as SWINMULTI +from .mUnet_transformer import build_model as MUNET_TRANSFORMER +from .munet_multi_transfuse import build_model as MUNET_TRANSFUSE +from .munet_swinfusion import build_model as MUNET_SWINFUSE +from .munet_multi_concat import build_model as MUNET_CONCAT +from .munet_concat_decomp import build_model as MUNET_CONCAT_DECOMP +from .munet_multi_sum import build_model as MUNET_SUM +# from .mUnet_ARTfusion_new import build_model as MUNET_ART_FUSION +from .mUnet_ARTfusion import build_model as MUNET_ART_FUSION +from .mUnet_Restormer_fusion import build_model as MUNET_RESTORMER_FUSION +from .mUnet_ARTfusion_SeqConcat import build_model as MUNET_ART_FUSION_SEQ + +from .ARTUnet import build_model as ART +from .Unet_ART import build_model as UNET_ART +from .unet_restormer import build_model as UNET_RESTORMER +from .mmunet_ART import build_model as MMUNET_ART +from .mmunet_restormer import build_model as MMUNET_RESTORMER +from .mmunet_restormer_v2 import build_model as MMUNET_RESTORMER_V2 +from .mmunet_restormer_3blocks import build_model as MMUNET_RESTORMER_SMALL + +from .mARTUnet import build_model as ART_MULTI_INPUT + +from .ART_Restormer import build_model as ART_RESTORMER +from .ART_Restormer_v2 import build_model as ART_RESTORMER_V2 +from .swinIR import build_model as SWINIR + +from .Kspace_ConvNet import build_model as KSPACE_CONVNET +from .Kspace_mUnet_new import build_model as KSPACE_MUNET +from .kspace_mUnet_concat import build_model as KSPACE_MUNET_CONCAT +from .kspace_mUnet_AttnFusion import build_model as KSPACE_MUNET_ATTNFUSE + +from .DuDo_mUnet import build_model as DUDO_MUNET +from .DuDo_mUnet_ARTfusion import build_model as DUDO_MUNET_ARTFUSION +from .DuDo_mUnet_CatFusion 
import build_model as DUDO_MUNET_CONCAT + + +model_factory = { + 'unet_single': UNET, + 'humus_single': HUMUS, + 'transformer_single': TRANS, + 'cnn_transformer': CNN_TRANS, + 'cnn_single': CNN, + 'swin_trans_single': SWIN_TRANS, + 'trans_unet': TRANSUNET, + 'unet_transformer': UNET_TRANSFORMER, + 'restormer': RESTORMER, + 'unet_art': UNET_ART, + 'unet_restormer': UNET_RESTORMER, + 'art': ART, + 'art_restormer': ART_RESTORMER, + 'art_restormer_v2': ART_RESTORMER_V2, + 'swinIR': SWINIR, + + + 'munet_transformer': MUNET_TRANSFORMER, + 'munet_transfuse': MUNET_TRANSFUSE, + 'cnn_late_multi': MULTI_CNN, + 'unet_multi': MUNET, + 'unet_late_multi':MMUNET_LATE, + 'unet_early_multi':MMUNET_EARLY, + 'munet_ARTfusion': MUNET_ART_FUSION, + 'munet_restormer_fusion': MUNET_RESTORMER_FUSION, + + + 'minet_multi': MINET, + 'sanet_multi': SANET, + 'mtrans_multi': MTRANS, + 'swin_fusion': SWINMULTI, + 'munet_swinfuse': MUNET_SWINFUSE, + 'munet_concat': MUNET_CONCAT, + 'munet_concat_decomp': MUNET_CONCAT_DECOMP, + 'munet_sum': MUNET_SUM, + 'mmunet_art': MMUNET_ART, + 'mmunet_restormer': MMUNET_RESTORMER, + 'mmunet_restormer_small': MMUNET_RESTORMER_SMALL, + 'mmunet_restormer_v2': MMUNET_RESTORMER_V2, + + 'art_multi_input': ART_MULTI_INPUT, + + 'munet_ARTfusion_SeqConcat': MUNET_ART_FUSION_SEQ, + + 'kspace_ConvNet': KSPACE_CONVNET, + 'kspace_munet': KSPACE_MUNET, + 'kspace_munet_concat': KSPACE_MUNET_CONCAT, + 'kspace_munet_AttnFusion': KSPACE_MUNET_ATTNFUSE, + + 'dudo_munet': DUDO_MUNET, + 'dudo_munet_ARTfusion': DUDO_MUNET_ARTFUSION, + 'dudo_munet_concat': DUDO_MUNET_CONCAT, +} + + +def build_model_from_name(args): + assert args.model_name in model_factory.keys(), 'unknown model name' + + return model_factory[args.model_name](args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c130178e7cdabaaef8686da0cec341265f3c43d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn.py @@ -0,0 +1,246 @@ +""" +Replace the transformer in the single-modality CNN_Transformer of our method with several conv_blocks that down-sample and extract features, to see whether this can match the performance of the U-Net. +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## single-modality CNN reconstructor (no transformer) +class CNN_Reconstructor(nn.Module): + def __init__(self): + super(CNN_Reconstructor, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + nlatent = self.encoder.output_dim + self.decoder = Decoder(n_upsample=4, n_res=1, dim=nlatent, output_dim=1, pad_type='zero') + + + def forward(self, m): + #4,1,240,240 + feature_output = self.encoder.model(m) # [4, 1024, 15, 15] + # print("output of the encoder:", feature_output.shape) + + output = self.decoder(feature_output) + + return output + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + 
nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + + def forward(self, input): + return self.model(input) + + + 
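+# --------------------------------------------------------------------------------
+# Hypothetical shape-check sketch (illustration only; nothing in the original
+# training pipeline calls it). It assumes the single-channel 240x240 inputs noted
+# elsewhere in this file: with n_downsample=4 and dim=64 the encoder yields a
+# [B, 1024, 15, 15] feature map, and the 4-step decoder maps it back to a
+# [B, 1, 240, 240] reconstruction.
+# --------------------------------------------------------------------------------
+def _shape_smoke_test(batch_size=2):
+    import torch  # local import keeps the sketch self-contained
+    model = CNN_Reconstructor()
+    dummy = torch.randn(batch_size, 1, 240, 240)  # [B, C, H, W] under-sampled slice
+    with torch.no_grad():
+        recon = model(dummy)
+    assert recon.shape == dummy.shape
+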
+################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + + +def build_model(args): + return CNN_Reconstructor() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn_late_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..58dacad6910b6e896d00a0d36c9ea60e8d77217f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn_late_fusion.py @@ -0,0 +1,196 @@ +""" +Based on Lequan's code, with the transformer-based feature fusion removed. +CNN features from the two modalities are simply concatenated as the multi-modal representation, and a plain decoder reconstructs the image. +T1 is used as the guidance modality and T2 as the target modality. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234-241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. 
+ out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers1 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + self.down_sample_layers2 = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers1.append(ConvBlock(ch, ch * 2, drop_prob)) + self.down_sample_layers2.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + + # self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + # ch = chans + # for _ in range(num_pool_layers - 1): + # self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + # ch *= 2 + # self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv = nn.Sequential(nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), nn.Tanh()) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = torch.cat([image, aux_image], 1) + # for layer in self.down_sample_layers: + # output = layer(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + + # apply down-sampling layers + output = image + for layer in self.down_sample_layers1: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + output_m1 = output + + output = aux_image + for layer in self.down_sample_layers2: + output = layer(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + output_m2 = output + + # print("two modality outputs:", output_m1.shape, output.shape) + output = torch.cat([output_m1, output_m2], 1) + + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv in self.up_transpose_conv: + output = transpose_conv(output) + + output = self.up_conv(output) + # print("model output:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..261d4d214e35c3a9be5d28bc426ce7280124723c --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn_transformer.py @@ -0,0 +1,289 @@ +""" +在our method中去掉多模态,直接用CNN提取图像特征,然后切patch得到patch embeddings. 
+""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## single-modality CNN + Transformer reconstructor +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=2, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=2, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 60 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,60,60] + patch_size = 4 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 256, 60, 60] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + feature_output = self.upsampling1(feature_output) # [4,512,30,30] + feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn_transformer_new.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn_transformer_new.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b2dea1e7cc8d456a684b6c839f052383999e84 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/cnn_transformer_new.py @@ -0,0 +1,290 @@ +""" +Remove the multi-modal part of our method: extract image features directly with a CNN, split them into patches to obtain patch embeddings, and feed them into a transformer network for MRI reconstruction. +2023/05/23: patches were previously cut from a 60*60 feature map; here we instead let the conv_blocks produce a 15*15 feature map and split it into 225 patches with patch_size = 1. +""" + +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from ..modules import * +import numpy as np + +## single-modality CNN + Transformer reconstructor +class CNN_Transformer(nn.Module): + def __init__(self): + super(CNN_Transformer, self).__init__() + self.n_res = 3 + + self.encoder = ContentEncoder_expand(n_downsample=4, n_res=self.n_res, input_dim=1, dim=64, norm='in', activ='relu', + pad_type='reflect') + + self.decoder = Decoder(n_upsample=4, n_res=1, dim=self.encoder.output_dim, output_dim=1, pad_type='zero') + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,1024,15,15] + patch_size = 1 + num_patch = (fmp_size // patch_size) * (fmp_size // patch_size) + patch_dim = 512*2 + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(self.encoder.output_dim, patch_dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + self.upsampling1 = nn.Sequential( + nn.ConvTranspose2d(patch_dim, patch_dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//2), + nn.ReLU(True) + ) + self.upsampling2 = nn.Sequential( + nn.ConvTranspose2d(patch_dim//2, patch_dim//4, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + nn.BatchNorm2d(patch_dim//4), + nn.ReLU(True) + ) + + def forward(self, m): + #4,1,240,240 + m_out = self.encoder.model(m) # [4, 1024, 15, 15] + + m_embed = self.to_patch_embedding(m_out) # [4, 225, 512*2], 225 is the number of patches. 
+ patch_embed = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + + b, n, _ = patch_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed),1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 225+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :].transpose(1,2) # [4, 1024, 225] + + h, w = int(np.sqrt(feature_output.shape[-1])), int(np.sqrt(feature_output.shape[-1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[1], h, w) # [4,512*2,15,15] + + # feature_output = self.upsampling1(feature_output) # [4,512,30,30] + # feature_output = self.upsampling2(feature_output) # [4,256,60,60] + + output = self.decoder(feature_output) + + return output + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +################################################################################## +# Encoder and Decoders +################################################################################## +class ContentEncoder_expand(nn.Module): + def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): + super(ContentEncoder_expand, self).__init__() + + self.model = [] + self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + # downsampling blocks + for i in 
range(n_downsample): + self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + dim *= 2 + self.model = nn.Sequential(*self.model) + + self.resblocks =[] + for i in range(n_res): + self.resblocks += [ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type)] + + self.model2 = nn.Sequential(*self.resblocks) + self.output_dim = dim + + # print("content_encoder model_1:", self.model) + # print("content_encoder model_2:", self.model2) + + def forward(self, x): + out = self.model(x) + out = self.model2(out) + return out + + +class Decoder(nn.Module): + def __init__(self, n_upsample, n_res, dim, output_dim, pad_type='zero'): + super(Decoder, self).__init__() + + norm_layer = nn.InstanceNorm2d + use_dropout = False + + self.n_res = n_res + + self.model = [] + for i in range(n_res): + self.model += [ResnetBlock(dim=dim, padding_type=pad_type, norm_layer=norm_layer, + use_dropout=use_dropout, use_bias=True)] + + for i in range(n_upsample): + self.model += [ + nn.ConvTranspose2d(dim, dim//2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + norm_layer(dim//2), + nn.ReLU(True)] + dim //= 2 + + self.model += [nn.ReflectionPad2d(3), nn.Conv2d(dim, output_dim, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*self.model) + + def forward(self, input): + return self.model(input) + + + +################################################################################## +# Basic Blocks +################################################################################## +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + + model = [] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] + model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + +class Conv2dBlock(nn.Module): + def __init__(self, input_dim ,output_dim, kernel_size, stride, + padding=0, norm='none', activation='relu', pad_type='zero'): + super(Conv2dBlock, self).__init__() + self.use_bias = True + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = output_dim + if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'none' or norm == 'sn': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=True) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=True) + elif activation == 'prelu': + self.activation = nn.PReLU() + elif activation == 'selu': + self.activation = nn.SELU(inplace=True) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) + + def forward(self, x): + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = 
self.activation(x) + return x + + +def build_model(args): + return CNN_Transformer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/humus_net.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/humus_net.py new file mode 100644 index 0000000000000000000000000000000000000000..d23701634f849b414866d71b3c23c6d07f3d6437 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/humus_net.py @@ -0,0 +1,1070 @@ +""" +HUMUS-Block +Hybrid Transformer-convolutional Multi-scale denoiser. + +We use parts of the code from +SwinIR (https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py) +Swin-Unet (https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.py) + +HUMUS-Net operates on k-space data with multiple unrolled cascades, each of which applies a HUMUS-block to denoise the image. +Since the input to our method is under-sampled images, no unrolling is needed; a single HUMUS-block is used directly for MRI reconstruction. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # print("x shape:", x.shape) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. 
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + 
# reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + +class PatchExpandSkip(nn.Module): + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) + self.channel_mix = nn.Linear(dim, dim // 2, bias=False) + self.norm = norm_layer(dim // 2) + + def forward(self, x, skip): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x = torch.cat([x, skip], dim=-1) + x = self.channel_mix(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, torch.tensor(x_size)) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ block_type: + D: downsampling block, + U: upsampling block + B: bottleneck block + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', block_type='B'): + super(RSTB, self).__init__() + self.dim = dim + conv_dim = dim // (patch_size ** 2) + divide_out_ch = 1 + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(conv_dim, conv_dim // divide_out_ch, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(conv_dim, conv_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(conv_dim // 4, conv_dim // divide_out_ch, 3, 1, 1)) + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=dim, + norm_layer=None) + + self.block_type = block_type + if block_type == 'B': + self.reshape = torch.nn.Identity() + elif block_type == 'D': + self.reshape = PatchMerging(input_resolution, dim) + elif block_type == 'U': + self.reshape = PatchExpandSkip([res // 2 for res in input_resolution], dim * 2) + else: + raise ValueError('Unknown RSTB block type.') + + def forward(self, x, x_size, skip=None): + if self.block_type == 'U': + assert skip is not None, "Skip connection is required for patch expand" + x = self.reshape(x, skip) + + out = self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) + block_out = self.patch_embed(out) + x + + if self.block_type == 'D': + block_out = (self.reshape(block_out), block_out) # return skip connection + + return block_out + +class PatchEmbed(nn.Module): + r""" Fixed Image to Patch Embedding. Flattens image along spatial dimensions + without learned projection mapping. Only supports 1x1 patches. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + embed_dim (int): Number of output channels. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + +class PatchEmbedLearned(nn.Module): + r""" Image to Patch Embedding with arbitrary patch size and learned linear projection. 
+ Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_chans * patch_size[0] * patch_size[1], embed_dim) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.patch_to_vec(x) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + def patch_to_vec(self, x): + """ + Args: + x: (B, C, H, W) + Returns: + patches: (B, num_patches, patch_size * patch_size * C) + """ + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(B, H // self.patch_size[0], self.patch_size[0], W // self.patch_size[1], self.patch_size[1], C) + patches = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.patch_size[0], self.patch_size[1], C) + patches = patches.contiguous().view(B, self.num_patches, self.patch_size[0] * self.patch_size[1] * C) + return patches + + +class PatchUnEmbed(nn.Module): + r""" Fixed Image to Patch Unembedding via reshaping. Only supports patch size 1x1. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. + """ + + def __init__(self, img_size, patch_size, embed_dim, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + +class PatchUnEmbedLearned(nn.Module): + r""" Patch Embedding to Image with learned linear projection. + Args: + img_size (int): Image size. + patch_size (int): Patch token size. + in_chans (int): Number of input image channels. + embed_dim (int): Number of linear projection output channels. + norm_layer (nn.Module, optional): Normalization layer. 
+ """ + + def __init__(self, img_size, patch_size, in_chans=None, out_chans=None, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.out_chans = out_chans + + self.proj_up = nn.Linear(in_chans, in_chans * patch_size[0] * patch_size[1]) + + def forward(self, x, x_size=None): + B, HW, C = x.shape + x = self.proj_up(x) + x = self.vec_to_patch(x) + return x + + def vec_to_patch(self, x): + B, HW, C = x.shape + x = x.view(B, self.patches_resolution[0], self.patches_resolution[1], C).contiguous() + x = x.view(B, self.patches_resolution[0], self.patch_size[0], self.patches_resolution[1], self.patch_size[1], C // (self.patch_size[0] * self.patch_size[1])).contiguous() + x = x.view(B, self.img_size[0], self.img_size[1], -1).contiguous().permute(0, 3, 1, 2) + return x + + +class HUMUSNet(nn.Module): + r""" HUMUS-Block + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depths (tuple(int)): Depth of each Swin Transformer layer in encoder and decoder paths. + num_heads (tuple(int)): Number of attention heads in different layers of encoder and decoder. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + use_checkpoint (bool): Whether to use checkpointing to save memory. + img_range: Image range. 1. or 255. + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + conv_downsample_first: use convolutional downsampling before MUST to reduce compute load on Transformers + """ + + def __init__(self, args, + img_size=[240, 240], + in_chans=1, + patch_size=1, + embed_dim=66, + depths=[2, 2, 2], + num_heads=[3, 6, 12], + window_size=5, + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + img_range=1., + resi_connection='1conv', + bottleneck_depth=2, + bottleneck_heads=24, + conv_downsample_first=True, + out_chans=1, + no_residual_learning=False, + **kwargs): + super(HUMUSNet, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans if out_chans is None else out_chans + + self.center_slice_out = (out_chans == 1) + + self.img_range = img_range + self.mean = torch.zeros(1, 1, 1, 1) + self.window_size = window_size + self.conv_downsample_first = conv_downsample_first + self.no_residual_learning = no_residual_learning + + ##################################################################################################### + ################################### 1, input block ################################### + input_conv_dim = embed_dim + self.conv_first = nn.Conv2d(num_in_ch, input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** self.num_layers) + self.mlp_ratio = mlp_ratio + + # Downsample for low-res feature extraction + if self.conv_downsample_first: + img_size = [im //2 for im in img_size] + self.conv_down_block = ConvBlock(input_conv_dim // 2, input_conv_dim, 0.0) + self.conv_down = DownsampConvBlock(input_conv_dim, input_conv_dim) + + # split image into non-overlapping patches + if patch_size > 1: + self.patch_embed = PatchEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + if patch_size > 1: + self.patch_unembed = PatchUnEmbedLearned( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, out_chans=embed_dim, norm_layer=norm_layer if self.patch_norm else None) + else: + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build MUST + # encoder + self.layers_down = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** i_layer) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + 
patches_resolution[1] // dim_scaler), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='D', + ) + self.layers_down.append(layer) + + # bottleneck + dim_scaler = (2 ** self.num_layers) + self.layer_bottleneck = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=bottleneck_depth, + num_heads=bottleneck_heads, + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='B', + ) + + # decoder + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + dim_scaler = (2 ** (self.num_layers - i_layer - 1)) + layer = RSTB(dim=int(embed_dim * dim_scaler), + input_resolution=(patches_resolution[0] // dim_scaler, + patches_resolution[1] // dim_scaler), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=False, + img_size=[im // dim_scaler for im in img_size], + patch_size=1, + resi_connection=resi_connection, + block_type='U', + ) + self.layers_up.append(layer) + + self.norm_down = norm_layer(self.num_features) + self.norm_up = norm_layer(self.embed_dim) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(input_conv_dim, input_conv_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(input_conv_dim, input_conv_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(input_conv_dim // 4, input_conv_dim, 3, 1, 1)) + + # Upsample if needed + if self.conv_downsample_first: + self.conv_up_block = ConvBlock(input_conv_dim, input_conv_dim // 2, 0.0) + self.conv_up = TransposeConvBlock(input_conv_dim, input_conv_dim // 2) + + ##################################################################################################### + ################################ 3, output block ################################ + self.conv_last = nn.Conv2d(input_conv_dim // 2 if self.conv_downsample_first else input_conv_dim, num_out_ch, 3, 1, 1) + self.output_norm = nn.Tanh() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + 
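+    # Overview of the forward pass (see forward() below):
+    #   1. conv_first lifts the input into feature space; with conv_downsample_first the
+    #      features are additionally halved in resolution by conv_down_block + conv_down.
+    #   2. forward_features runs the Swin-style encoder (layers_down), bottleneck and
+    #      decoder (layers_up) with skip connections.
+    #   3. conv_after_body, optional transpose-conv upsampling with concatenation of the
+    #      full-resolution features, and conv_last produce a residual that is added to the
+    #      (center-slice) input unless no_residual_learning is set, followed by Tanh.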
@torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + + # divisible by window size + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + + # divisible by total downsampling ratio + # this could be done more efficiently by combining the two + total_downsamp = int(2 ** (self.num_layers - 1)) + pad_h = (total_downsamp - h % total_downsamp) % total_downsamp + pad_w = (total_downsamp - w % total_downsamp) % total_downsamp + x = F.pad(x, (0, pad_w, 0, pad_h)) + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # encode + skip_cons = [] + for i, layer in enumerate(self.layers_down): + x, skip = layer(x, x_size=layer.input_resolution) + skip_cons.append(skip) + x = self.norm_down(x) # B L C + + # bottleneck + x = self.layer_bottleneck(x, self.layer_bottleneck.input_resolution) + + # decode + for i, layer in enumerate(self.layers_up): + x = layer(x, x_size=layer.input_resolution, skip=skip_cons[-i-1]) + x = self.norm_up(x) + x = self.patch_unembed(x, x_size) + return x + + def forward(self, x): + # print("input shape:", x.shape) + # print("input range:", x.max(), x.min()) + C, H, W = x.shape[1:] + center_slice = (C - 1) // 2 + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.conv_downsample_first: + x_first = self.conv_first(x) + x_down = self.conv_down(self.conv_down_block(x_first)) + res = self.conv_after_body(self.forward_features(x_down)) + res = self.conv_up(res) + res = torch.cat([res, x_first], dim=1) + res = self.conv_up_block(res) + + res = self.conv_last(res) + + if self.no_residual_learning: + x = res + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + res + else: + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + + if self.no_residual_learning: + x = self.conv_last(res) + else: + if self.center_slice_out: + x = x[:, center_slice, ...].unsqueeze(1) + x = x + self.conv_last(res) + + # print("output range before scaling:", x.max(), x.min()) + # print("self.mean:", self.mean) + + if self.center_slice_out: + x = x / self.img_range + self.mean[:, center_slice, ...].unsqueeze(1) + else: + x = x / self.img_range + self.mean + + x = self.output_norm(x) + + return x[:, :, :H, :W] + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + +class DownsampConvBlock(nn.Module): + """ + A Downsampling Convolutional Block that consists of one strided convolution + layer followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.Conv2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + Returns: + Output tensor of shape `(N, out_chans, H/2, W/2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return HUMUSNet(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/kspace_mUnet_AttnFusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/kspace_mUnet_AttnFusion.py new file mode 100644 index 0000000000000000000000000000000000000000..02588c24559c0604ddbbd03a14b3b32e6cff2c8d --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/kspace_mUnet_AttnFusion.py @@ -0,0 +1,300 @@ +""" +Xiaohan Xing, 2023/11/22 +两个模态分别用CNN提取多个层级的特征, 每个层级都把特征沿着宽度拼接起来,得到(h, 2w, c)的特征。 +然后学习得到(h, 2w)的attention map, 和原始的特征相乘送到后面的层。 +""" + +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.attn_layers = nn.ModuleList() + + for l in range(self.num_pool_layers): + + self.attn_layers.append(nn.Sequential(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob), + nn.Conv2d(chans*(2**l), 2, kernel_size=1, stride=1), + nn.Sigmoid())) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, attn_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.attn_layers): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat之后求attention, 然后进行modulation. 
+ attn = attn_layer(torch.cat((output1, output2), 1)) + # print("spatial attention:", attn[:, 0, :, :].unsqueeze(1).shape) + # print("output1:", output1.shape) + output1 = output1 * (1 + attn[:, 0, :, :].unsqueeze(1)) + output2 = output2 * (1 + attn[:, 1, :, :].unsqueeze(1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/kspace_mUnet_concat.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/kspace_mUnet_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..d16bfdc83305d3e1f490e5ca4b445c6d730557ce --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/kspace_mUnet_concat.py @@ -0,0 +1,351 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Predict the low-frequency mask ############## +class Mask_Predictor(nn.Module): + def __init__(self, in_chans, out_chans, chans, drop_prob): + super(Mask_Predictor, self).__init__() + self.conv1 = ConvBlock(in_chans, chans, drop_prob) + self.conv2 = ConvBlock(chans, chans*2, drop_prob) + self.conv3 = ConvBlock(chans*2, chans*4, drop_prob) + + # self.conv4 = ConvBlock(chans*4, 1, drop_prob) + + self.FC = nn.Linear(chans*4, out_chans) + self.act = nn.Sigmoid() + + + def forward(self, x): + ### 三层卷积和down-sampling. + # print("[Mask predictor] input data range:", x.max(), x.min()) + output = F.avg_pool2d(self.conv1(x), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer1_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv2(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer2_feature range:", output.max(), output.min()) + output = F.avg_pool2d(self.conv3(output), kernel_size=2, stride=2, padding=0) + # print("[Mask predictor] layer3_feature range:", output.max(), output.min()) + + feature = F.adaptive_avg_pool2d(F.relu(output), (1, 1)).squeeze(-1).squeeze(-1) + # print("mask_prediction feature:", output[:, :5]) + # print("[Mask predictor] bottleneck feature range:", feature.max(), feature.min()) + + output = self.FC(feature) + # print("output before act:", output) + output = self.act(output) + # print("mask output:", output) + + return output + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) 
+ + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 2, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + if args.HF_refine == "True": + self.in_chans = self.in_chans * 2 + + self.mask_predictor1 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + self.mask_predictor2 = Mask_Predictor(self.in_chans, 3, self.chans, self.drop_prob) + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, kspace, ref_kspace, recon_kspace, recon_ref_kspace) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("kspace and ref_kspace:", kspace.shape, ref_kspace.shape) + # if recon_kspace is not None: + # print("recon_kspace and recon_ref_kspace:", recon_kspace.shape, recon_ref_kspace.shape) + + + # print("T2 input_kspace range:", kspace.max(), kspace.min()) + # print("T2 recon_kspace range:", recon_kspace.max(), recon_kspace.min()) + # print("T1 input_kspace range:", ref_kspace.max(), ref_kspace.min()) + # print("T1 recon_kspace range:", recon_ref_kspace.max(), recon_ref_kspace.min()) + + if recon_kspace is not None: + output1 = torch.cat((kspace, recon_kspace), 1) + output2 = torch.cat((ref_kspace, recon_ref_kspace), 1) + else: + output1, output2 = kspace, ref_kspace + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + + ### 预测做low-freq DC的mask区域. + ### 输出三个维度, 前两维是坐标,最后一维是mask内部的权重 + t1_mask = self.mask_predictor1(output1) + t2_mask = self.mask_predictor2(output2) + + # print("t1_mask and t2_mask:", t1_mask.shape, t2_mask.shape) + t1_mask_coords, t2_mask_coords = t1_mask[:, :2], t2_mask[:, :2] + t1_DC_weight, t2_DC_weight = t1_mask[:, -1], t2_mask[:, -1] + + + # ### 预测做low-freq DC的DC weight + # t1_DC_weight = self.mask_predictor1(output1) + # t2_DC_weight = self.mask_predictor2(output2) + + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, t1_mask_coords, t2_mask_coords, t1_DC_weight, t2_DC_weight ## + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mARTUnet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mARTUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9607d41ca14562a94e542a1804af5aeaf15bc123 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mARTUnet.py @@ -0,0 +1,563 @@ +""" +This script implements the method in "ICLR 2023: ACCURATE IMAGE RESTORATION WITH ATTENTION RETRACTABLE TRANSFORMER". 
+Attention Retractable Transformer (ART) model for Real Image Denoising task +2023/07/20, Xiaohan Xing +将两个模态的图像在输入端concat, 得到num_channels = 2的input, 然后送到ART模型进行T2 modality的重建。 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from pdb import set_trace as stx +import numbers + +from einops import rearrange +import math + +NEG_INF = -1000000 + + +########################################################################## +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + + def forward(self, biases): + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + def flops(self, N): + flops = N * 2 * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.pos_dim + flops += N * self.pos_dim * self.num_heads + return flops + + +######################################### +class Attention(nn.Module): + r""" Multi-head self attention module with dynamic position bias. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + position_bias=True): + + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.position_bias = position_bias + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W, mask=None): + """ + Args: + x: input features with shape of (num_groups*B, N, C) + mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None + H: height of each group + W: width of each group + """ + group_size = (H, W) + B_, N, C = x.shape + assert H * W == N + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1).contiguous() # (B_, self.num_heads, N, N), N = H*W + + if self.position_bias: + # generate mother-set + position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device) + position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) # 2, 2Gh-1, 2W2-1 + biases = biases.flatten(1).transpose(0, 1).contiguous().float() # (2h-1)*(2w-1) 2 + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(group_size[0], device=attn.device) + coords_w = torch.arange(group_size[1], device=attn.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Gh, Gw + coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Gh*Gw, Gh*Gw + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Gh*Gw, Gh*Gw, 2 + relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += group_size[1] - 1 + relative_coords[:, :, 0] *= 2 * group_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw + + pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads + # select position bias + relative_position_bias = pos[relative_position_index.view(-1)].view( + group_size[0] * group_size[1], group_size[0] * group_size[1], -1) # Gh*Gw,Gh*Gw,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Gh*Gw, Gh*Gw + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nP = mask.shape[0] + attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # (B, nP, nHead, N, N) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +########################################################################## +class TransformerBlock(nn.Module): + r""" ART Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size: window size of dense attention + interval: interval size of sparse attention + ds_flag (int): use Dense Attention or Sparse Attention, 0 for DAB and 1 for SAB. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + # act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + interval=8, + ds_flag=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.interval = interval + self.ds_flag = ds_flag + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + + self.attn = Attention( + dim, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + position_bias=True) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W) + + if min(H, W) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.ds_flag = 0 + self.window_size = min(H, W) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # print("feature after layer_norm:", x.max(), x.min()) + + # padding + size_par = self.interval if self.ds_flag == 1 else self.window_size + pad_l = pad_t = 0 + pad_r = (size_par - W % size_par) % size_par + pad_b = (size_par - H % size_par) % size_par + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hd, Wd, _ = x.shape + + mask = torch.zeros((1, Hd, Wd, 1), device=x.device) + if pad_b > 0: + mask[:, -pad_b:, :, :] = -1 + if pad_r > 0: + mask[:, :, -pad_r:, :] = -1 + + # partition the whole feature map into several groups + if self.ds_flag == 0: # Dense Attention + G = Gh = Gw = self.window_size + x = x.reshape(B, Hd // G, G, Wd // G, G, C).permute(0, 1, 3, 2, 4, 5).contiguous() + x = x.reshape(B * Hd * Wd // G ** 2, G ** 2, C) + nP = Hd * Wd // G ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Hd // G, G, Wd // G, G, 1).permute(0, 1, 3, 2, 4, 5).contiguous() + mask = mask.reshape(nP, 1, G * G) + attn_mask = torch.zeros((nP, G * G, G * G), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + if self.ds_flag == 1: # Sparse Attention + I, Gh, Gw = self.interval, Hd // self.interval, Wd // self.interval + x = x.reshape(B, Gh, I, Gw, I, C).permute(0, 2, 4, 1, 3, 5).contiguous() + x = x.reshape(B * I * I, Gh * Gw, C) + nP = I ** 2 # number of partitioning groups + # attn_mask + if pad_r > 0 or pad_b > 0: + mask = mask.reshape(1, Gh, I, Gw, I, 1).permute(0, 2, 4, 1, 3, 5).contiguous() + mask = mask.reshape(nP, 1, Gh * Gw) + attn_mask = torch.zeros((nP, Gh * Gw, Gh * Gw), device=x.device) + attn_mask = attn_mask.masked_fill(mask < 0, NEG_INF) + else: + attn_mask = None + + # MSA + # print("attn mask:", 
attn_mask) + x = self.attn(x, Gh, Gw, mask=attn_mask) # nP*B, Gh*Gw, C + # print("output of the Attention block:", x.max(), x.min()) + + # merge embeddings + if self.ds_flag == 0: + x = x.reshape(B, Hd // G, Wd // G, G, G, C).permute(0, 1, 3, 2, 4, + 5).contiguous() # B, Hd//G, G, Wd//G, G, C + else: + x = x.reshape(B, I, I, Gh, Gw, C).permute(0, 3, 1, 4, 2, 5).contiguous() # B, Gh, I, Gw, I, C + x = x.reshape(B, Hd, Wd, C) + + # remove padding + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + # print("x:", x.shape) + + return x + + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, ds_flag={self.ds_flag}, mlp_ratio={self.mlp_ratio}" + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x, H, W): + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W).contiguous() + x = self.body(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + return x + + +class ARTUNet(nn.Module): + def __init__(self, + inp_channels=2, + out_channels=1, + dim=48, + num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[8, 8, 8, 8], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., + interval=[32, 16, 8, 4], + bias=False, + dual_pixel_task=False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(ARTUNet, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.ModuleList([ + TransformerBlock(dim=dim, + num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4 + self.latent = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], window_size=window_size[3], + interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], window_size=window_size[2], + interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], window_size=window_size[1], + interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_blocks[0])]) + + self.refinement = nn.ModuleList([ + TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], window_size=window_size[0], + interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + ) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2 ** 1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img, aux_img): + stack = [] + bs, _, H, W = inp_img.shape + fuse_img = torch.cat((inp_img, aux_img), 1) + inp_enc_level1 = self.patch_embed(fuse_img) # b,hw,c + out_enc_level1 = inp_enc_level1 + for layer in self.encoder_level1: + out_enc_level1 = layer(out_enc_level1, [H, W]) + stack.append(out_enc_level1.view(bs, H, W, -1).permute(0, 3, 1, 2)) + print("out_enc_level1:", out_enc_level1.max(), out_enc_level1.min()) + + inp_enc_level2 = self.down1_2(out_enc_level1, H, W) # b, hw//4, 2c + out_enc_level2 = inp_enc_level2 + for layer in self.encoder_level2: + out_enc_level2 = layer(out_enc_level2, [H // 2, W // 2]) + stack.append(out_enc_level2.view(bs, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + print("out_enc_level2:", out_enc_level2.max(), out_enc_level2.min()) + + inp_enc_level3 = self.down2_3(out_enc_level2, H // 2, W // 2) # b, hw//16, 4c + out_enc_level3 = inp_enc_level3 + for layer in self.encoder_level3: + out_enc_level3 = layer(out_enc_level3, [H // 4, W // 4]) + stack.append(out_enc_level3.view(bs, H // 4, W // 4, -1).permute(0, 3, 1, 2)) + print("out_enc_level3:", out_enc_level3.max(), out_enc_level3.min()) + + inp_enc_level4 = self.down3_4(out_enc_level3, H // 4, W // 4) # b, hw//64, 8c + latent = inp_enc_level4 + for layer in self.latent: + latent = layer(latent, [H // 8, W // 8]) + stack.append(latent.view(bs, H // 8, W // 8, -1).permute(0, 3, 1, 2)) + print("latent feature:", latent.max(), latent.min()) + + inp_dec_level3 = self.up4_3(latent, H // 8, W // 8) # b, hw//16, 4c + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 2) + inp_dec_level3 = rearrange(inp_dec_level3, "b (h w) c -> b c h w", h=H // 4, w=W // 4).contiguous() + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + inp_dec_level3 = rearrange(inp_dec_level3, "b c h w -> b (h w) c").contiguous() # b, hw//16, 4c + out_dec_level3 = inp_dec_level3 + for layer in self.decoder_level3: + out_dec_level3 = layer(out_dec_level3, [H // 4, W // 4]) + + inp_dec_level2 = self.up3_2(out_dec_level3, H // 4, W // 4) # b, hw//4, 2c + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 2) + inp_dec_level2 = rearrange(inp_dec_level2, "b (h w) c -> b c h w", h=H // 2, w=W // 2).contiguous() + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + inp_dec_level2 = rearrange(inp_dec_level2, "b c h w -> b (h w) c").contiguous() # b, hw//4, 2c + out_dec_level2 = inp_dec_level2 + for layer in self.decoder_level2: + out_dec_level2 = layer(out_dec_level2, [H // 2, W // 2]) + + inp_dec_level1 = self.up2_1(out_dec_level2, H // 2, W // 2) # b, hw, c + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 2) + out_dec_level1 = inp_dec_level1 + for layer in self.decoder_level1: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + for layer in self.refinement: + out_dec_level1 = layer(out_dec_level1, [H, W]) + + out_dec_level1 = rearrange(out_dec_level1, "b (h w) c -> b c h w", h=H, w=W).contiguous() + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + 
self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + return out_dec_level1 #, stack + + + +def build_model(args): + return ARTUNet( + inp_channels=2, + out_channels=1, + dim=32, + num_blocks=[4, 4, 4, 4], + # dim=48, + # num_blocks=[4, 6, 6, 8], + num_refinement_blocks=4, + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], mlp_ratio=4., + qkv_bias=True, qk_scale=None, + interval=[24, 12, 6, 3], + bias=False, + dual_pixel_task=False + ) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_ARTfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_ARTfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..610b1e1e38d74f695a3a19f22a6a80f3152bf713 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_ARTfusion.py @@ -0,0 +1,305 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet import CS_TransformerBlock as TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ 
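+    Multi-modal U-Net in which the two modalities are encoded separately and their
+    features are fused at every level by ART transformer blocks.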
+ 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4.): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + # output1, output2 = image1, image2 + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
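A self-contained sketch of the per-level fusion pattern used in the encoder loop here: run both modality conv blocks, concatenate their maps along channels, flatten to a token sequence, apply the transformer blocks, reshape, split the channels back per modality, and average-pool before the next level. In this sketch nn.TransformerEncoderLayer is only a stand-in for the imported CS_TransformerBlock (whose real forward also takes the spatial size), and the 30x30 map mirrors the deepest level of a 240x240 input, a size for which every level stays divisible by the window size of 10 and the interval schedule [24, 12, 6, 3].

import torch
from torch import nn
from einops import rearrange

def fuse_one_level(feat1, feat2, fuse_block):
    # concat the two modality maps along channels, fuse as a token sequence,
    # then split back into per-modality maps
    b, c, h, w = feat1.shape
    tokens = rearrange(torch.cat((feat1, feat2), dim=1), "b c h w -> b (h w) c")
    tokens = fuse_block(tokens)                              # (b, h*w, 2c)
    fused = rearrange(tokens, "b (h w) c -> b c h w", h=h, w=w)
    return fused[:, :c], fused[:, c:]

# stand-in for the imported CS_TransformerBlock (whose forward also takes [H, W])
block = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
f1 = torch.randn(1, 256, 30, 30)                             # level-3 sized features
f2 = torch.randn(1, 256, 30, 30)
o1, o2 = fuse_one_level(f1, f2, block)                       # each (1, 256, 30, 30)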
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + + # print("range of Unet feature:", output1.max(), output1.min()) + # print("range of Unet feature:", output2.max(), output2.min()) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + output = layer(output, [H // (2**l), W // (2**l)]) + + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + t1_features.append(output1) + t2_features.append(output2) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
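The get_relation_matrix helper above folds each map into 15x15 tiles, pools to a 5x5 grid, and builds a normalized Gram matrix. A simplified sketch that keeps only the pooling plus cosine-similarity core (the tiling step of the original is omitted here):

import torch
import torch.nn.functional as F

def cosine_relation_matrix(feature, grid=5):
    # pool the map to a small grid and compute pairwise cosine similarity
    # between the pooled positions
    pooled = F.adaptive_avg_pool2d(feature, grid)            # (b, c, grid, grid)
    tokens = pooled.flatten(2).transpose(1, 2)               # (b, grid*grid, c)
    tokens = F.normalize(tokens, p=2.0, dim=-1)
    return tokens @ tokens.transpose(1, 2)                   # (b, grid*grid, grid*grid)

relation = cosine_relation_matrix(torch.randn(2, 64, 30, 30))   # (2, 25, 25)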
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_ARTfusion_SeqConcat.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_ARTfusion_SeqConcat.py new file mode 100644 index 0000000000000000000000000000000000000000..55280925415f1db47cf746b59fb9710d47c9bff2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_ARTfusion_SeqConcat.py @@ -0,0 +1,374 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. +""" +# coding: utf-8 +from typing import Any + +import matplotlib.pyplot as plt +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .ARTUnet_new import ConcatTransformerBlock, TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class 
mUnet_ARTfusion_SeqConcat(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8], + window_size=[10, 10, 10, 10], + interval=[24, 12, 6, 3], + mlp_ratio=4. + ): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.vis_feature_maps = False + self.vis_attention_maps = False + # self.vis_feature_maps = args.MODEL.VIS_FEAT_MAPS + # self.vis_attention_maps = args.MODEL.VIS_ATTN_MAPS + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + if l < 2: + self.fuse_layers.append(nn.ModuleList([TransformerBlock(dim=int(chans * (2 ** (l + 1))), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + else: + self.fuse_layers.append(nn.ModuleList([ConcatTransformerBlock(dim=int(chans * (2 ** l)), num_heads=heads[l], + window_size=window_size[l], interval=interval[l], ds_flag=0 if i % 2 == 0 else 1, + mlp_ratio=mlp_ratio, qkv_bias=True, qk_scale=None, drop=0., + attn_drop=0., drop_path=0.,visualize_attention_maps=self.vis_attention_maps) for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + feature_maps = { + 'modal1': {'before': [], 'after': []}, + 'modal2': {'before': [], 'after': []} + } + + attention_maps = [] + """ + attention_maps = [ + unet layer 1 attention maps (x2): [map from TransBlock1, TransBlock2, ...], + unet layer 2 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 3 attention maps (x4): [map from TransBlock1, TransBlock2, ...], + unet layer 4 attention maps (x6): [map from TransBlock1, TransBlock2, ...], + ] + """ + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + # print("number of blocks in fuse layer:", len(fuse_layer)) + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
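ConcatTransformerBlock is imported from ARTUnet_new and its internals are not shown in this file; the call sites below only show that it takes two token sequences plus the spatial size and returns two sequences. One plausible reading of the "SeqConcat" name is that each modality's tokens attend over the concatenation of both sequences; the toy block below implements that reading with standard multi-head attention and is not the repository's actual block.

import torch
from torch import nn

class ToySeqConcatBlock(nn.Module):
    # each modality's tokens attend over the concatenation of both token sequences
    def __init__(self, dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, tokens1, tokens2):
        q1, q2 = self.norm1(tokens1), self.norm2(tokens2)
        joint = torch.cat((q1, q2), dim=1)                   # concat along the sequence axis
        out1, _ = self.attn(q1, joint, joint)
        out2, _ = self.attn(q2, joint, joint)
        return tokens1 + out1, tokens2 + out2

t1 = torch.randn(1, 900, 128)                                # a 30x30 map flattened to 900 tokens
t2 = torch.randn(1, 900, 128)
o1, o2 = ToySeqConcatBlock(dim=128, num_heads=4)(t1, t2)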
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + if self.vis_feature_maps: + feature_maps['modal1']['before'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['before'].append(output2.detach().cpu().numpy()) + + ### 两个模态的特征concat + # output = torch.cat((output1, output2), 1) + + + + unet_layer_attn_maps = [] + """ + unet_layer_attn_maps: [ + attn_map from TransformerBlock1: (B, nHGroups, nWGroups, winSize, winSize), + attn_map from TransformerBlock2: (B, nHGroups, nWGroups, winSize, winSize), + ... + ] + """ + if l < 2: + output = torch.cat((output1, output2), 1) + output = rearrange(output, "b c h w -> b (h w) c").contiguous() + else: + output1 = rearrange(output1, "b c h w -> b (h w) c").contiguous() + output2 = rearrange(output2, "b c h w -> b (h w) c").contiguous() + + for layer in fuse_layer: + if l < 2: + if self.vis_attention_maps: + output, attn_map = layer(output, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output = layer(output, [H // (2**l), W // (2**l)]) + + else: + if self.vis_attention_maps: + output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) + unet_layer_attn_maps.append(attn_map) + else: + output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + attention_maps.append(unet_layer_attn_maps) + + if l < 2: + output = rearrange(output, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + else: + output1 = rearrange(output1, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + output2 = rearrange(output2, "b (h w) c -> b c h w", h=H // (2**l), w=W // (2**l)).contiguous() + + # if self.vis_attention_maps: + # output1, output2, attn_map = layer(output1, output2, [H // (2**l), W // (2**l)]) # attn_map: (B, nHGroups, nWGroups, winSize, winSize) + # unet_layer_attn_maps.append(attn_map) + # else: + # output1, output2 = layer(output1, output2, [H // (2**l), W // (2**l)]) + + # attention_maps.append(unet_layer_attn_maps) + + + + if self.vis_feature_maps: + feature_maps['modal1']['after'].append(output1.detach().cpu().numpy()) + feature_maps['modal2']['after'].append(output2.detach().cpu().numpy()) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + if self.vis_feature_maps: + return output1, output2, feature_maps#, relation_stack1, relation_stack2 #, t1_features, t2_features + elif self.vis_attention_maps: + return output1, output2, attention_maps + else: + return output1, output2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
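When vis_attention_maps is enabled, the forward pass above returns attention_maps as a per-level list of per-block maps whose documented shape is (B, nHGroups, nWGroups, winSize, winSize). A minimal way to inspect one of them, assuming that documented shape holds:

import matplotlib.pyplot as plt

def show_window_attention(attention_maps, level=0, block=0):
    # attention_maps[level][block] is assumed to be
    # (B, nHGroups, nWGroups, winSize, winSize), as documented above
    attn = attention_maps[level][block].detach().float().cpu()
    avg = attn.mean(dim=(0, 1, 2))                           # average over batch and window groups
    plt.imshow(avg.numpy(), cmap="viridis")
    plt.colorbar()
    plt.title(f"level {level}, block {block}: mean window attention")
    plt.show()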
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion_SeqConcat(args) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_Restormer_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_Restormer_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8ab29b641fa0231ad0d163224a4a90b44c32a9 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_Restormer_fusion.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/07/26 +两个模态分别用Unet提取多个层级的特征, 每个层级都用ART block融合多模态特征. +将fused feature送到encoder的下一个层级和decoder. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from .restormer import TransformerBlock + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_ARTfusion(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过ART block融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + num_blocks=[2, 4, 4, 6], + heads=[1, 2, 4, 8]): + + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + + self.fuse_layers = nn.ModuleList() + for l in range(self.num_pool_layers): + self.fuse_layers.append(nn.Sequential(*[TransformerBlock(dim=int(chans * 2 ** (l+1)), num_heads=heads[0], ffn_expansion_factor=2.66, \ + bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[l])])) + + # print(self.fuse_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + l = 0 + bs, _, H, W = image1.shape + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, fuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_layers): + + ### 将encoder中multi-modal fusion之后的特征通过skip connection连接到decoder. + output1 = net1_layer(output1) + output2 = net2_layer(output2) + + stack1.append(output1) + stack2.append(output2) + + ### 两个模态的特征concat + output = torch.cat((output1, output2), 1) + output = fuse_layer(output) + + output1 = output[:, :output.shape[1]//2, :, :] + output2 = output[:, output.shape[1]//2:, :, :] + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + l += 1 + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_ARTfusion(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2a209758203b198502154b75496333ef91717a --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mUnet_transformer.py @@ -0,0 +1,387 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+ +Xiaohan Xing, 2023/06/08 +分别用CNN提取两个模态的特征,然后送到transformer进行特征交互, 将经过transformer提取的特征和之前CNN提取的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + 
self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Decoder_wo_skip(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder_wo_skip, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + output = transpose_conv(output) + output = conv(output) + + return output + + + + + +class mUnet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch*2 + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output1, output2 = image1, image2 + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack1, output1 = self.encoder1(output1) ### output size: [4, 256, 15, 15] + stack2, output2 = self.encoder2(output2) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output1) # [4, 225, 256], 225 is the number of patches. + patch_embed_input1 = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + n_embed = self.to_patch_embedding(output2) # [4, 225, 256], 225 is the number of patches. 
+ patch_embed_input2 = F.normalize(n_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input1, patch_embed_input2), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1]/2)), int(np.sqrt(feature_output.shape[1]/2)) + patch_num = feature_output.shape[1] + feature_output1 = feature_output[:, :patch_num//2, :] + feature_output2 = feature_output[:, patch_num//2:, :] + feature_output1 = feature_output1.contiguous().view(b, feature_output1.shape[-1], h, w) # [4,c,15,15] + feature_output2 = feature_output2.contiguous().view(b, feature_output2.shape[-1], h, w) + + output = torch.cat((output1, output2), 1) + torch.cat((feature_output1, feature_output2), 1) + # print("CNN feature range:", output1.max(), output1.min()) + # print("Transformer feature range:", feature_output1.max(), feature_output1.min()) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
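The bottleneck fusion above L2-normalizes the patch tokens of both modalities, prepends a single cls token, adds a learned positional embedding over the doubled sequence, runs the transformer, then splits the sequence back into two feature maps. A compact sketch of the same flow is below. Note that it permutes before reshaping back to (b, c, h, w), whereas the code above views (b, n, c) directly into (b, c, h, w); that view matches element counts but not the token-to-position layout, so that step is worth double-checking if spatial alignment matters.

import torch
from torch import nn
import torch.nn.functional as F

def fuse_at_bottleneck(feat1, feat2, transformer, cls_token, pos_embedding):
    b, c, h, w = feat1.shape
    tok1 = F.normalize(feat1.flatten(2).transpose(1, 2), dim=-1)     # (b, h*w, c)
    tok2 = F.normalize(feat2.flatten(2).transpose(1, 2), dim=-1)
    cls = cls_token.expand(b, -1, -1)                                 # (b, 1, c)
    seq = torch.cat((cls, tok1, tok2), dim=1) + pos_embedding         # (b, 2*h*w + 1, c)
    seq = transformer(seq)[:, 1:]                                     # drop the cls token
    out1, out2 = seq.chunk(2, dim=1)
    to_map = lambda t: t.transpose(1, 2).reshape(b, c, h, w)          # permute, then reshape
    return to_map(out1), to_map(out2)

c, h, w = 256, 15, 15
enc = nn.TransformerEncoder(nn.TransformerEncoderLayer(c, nhead=8, batch_first=True), num_layers=2)
cls_token = nn.Parameter(torch.randn(1, 1, c))
pos = nn.Parameter(torch.randn(1, 2 * h * w + 1, c))
f1, f2 = fuse_at_bottleneck(torch.randn(1, c, h, w), torch.randn(1, c, h, w), enc, cls_token, pos)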
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_Transformer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee0c4784d04638df4ca259613777dde34980a2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/04 +Concatenate the input images from different modalities and input it into the UNet. +""" +# coding: utf-8 +import torch +from torch import nn +from torch.nn import functional as F + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. 
+ # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("down-sampling layer output:", output.shape) + + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # print("feature range of this layer:", image.max(), image.min()) + # return image + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_ART.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_ART.py new file mode 100644 index 0000000000000000000000000000000000000000..1a09f2900a8924cdc1a7c373270e6ff8aadcfa45 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_ART.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1, pre_norm=False) for i in 
range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
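The token-to-map reshape above recovers (b, c, h, w) using output.shape[2] for both spatial sizes, which assumes square feature maps (consistent with the hard-coded self.img_size = 240). An equivalent, shape-explicit helper, assuming the level index is known:

import torch
from einops import rearrange

def tokens_to_map(tokens, img_size, level):
    # H == W is assumed at every level, matching the hard-coded img_size above
    side = img_size // (2 ** level)
    return rearrange(tokens, "b (h w) c -> b c h w", h=side, w=side)

maps = tokens_to_map(torch.randn(1, 3600, 64), img_size=240, level=2)    # (1, 64, 60, 60)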
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_ART_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_ART_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..adfe6bc0d9192a126f39932b40443badb5b60068 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_ART_v2.py @@ -0,0 +1,256 @@ +""" +2023/07/06, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个ART transformer block提取long-range dependent feature. +""" +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .ARTfuse_layer import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. 
+ num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + + self.transform_layers = nn.Sequential(nn.Conv2d(self.chans, self.chans, kernel_size=1), + nn.Conv2d(self.chans*2, self.chans*2, kernel_size=1), + nn.Conv2d(self.chans*4, self.chans*4, kernel_size=1), + nn.Conv2d(self.chans*8, self.chans*8, kernel_size=1)) + + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + # ART transformer fusion modules + # num_blocks=[4, 6, 6, 8] + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + window_size=[10, 10, 10, 10] + interval=[24, 12, 6, 3] + + self.ART_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], window_size=window_size[0], interval=interval[0], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[0])]) + self.ART_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], window_size=window_size[1], interval=interval[1], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[1])]) + self.ART_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], window_size=window_size[2], interval=interval[2], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[2])]) + self.ART_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], window_size=window_size[3], interval=interval[3], + ds_flag=0 if i % 2 == 0 else 1) for i in range(num_blocks[3])]) + self.ART_blocks = nn.Sequential(self.ART_block1, self.ART_block2, self.ART_block3, self.ART_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
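Compared with mmunet_ART.py, this v2 decoder head ends with an active nn.Tanh() and the ART blocks are built without pre_norm=False, so reconstructions are bounded to (-1, 1). If that head is used, targets are normally scaled into the same range; a minimal sketch of such a mapping, assuming images are already min-max normalized to [0, 1]:

def to_tanh_range(x):
    # map an image in [0, 1] to [-1, 1] to match a Tanh output head
    # (assumption: the input has already been min-max normalized to [0, 1])
    return x * 2.0 - 1.0

def from_tanh_range(y):
    # inverse mapping back to [0, 1] for metric computation or saving
    return (y + 1.0) / 2.0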
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, art_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + l = 0 + for layer, conv_layer, art_block in zip(self.down_sample_layers, self.transform_layers, self.ART_blocks): + output = layer(output) + # print("Unet feature range:", output.shape, output.max(), output.min()) + output = conv_layer(output) + # print("feature after conv_transform layer:", output.shape, output.max(), output.min()) + # feature1, feature2 = artfuse_layer(output1, output2) + feature = output.view(output.shape[0], output.shape[1], -1).permute(0, 2, 1) ## [bs, h*w, dim] + for art_layer in art_block: + # feature1, feature2 = artfuse_layer(feature1, feature2, [self.img_size//2**layer, self.img_size//2**layer]) + feature = art_layer(feature, [self.img_size//2**l, self.img_size//2**l]) + # print("intermediate feature1:", feature1.shape) + + feature = feature.view(output.shape[0], output.shape[2], output.shape[2], output.shape[1]).permute(0, 3, 1, 2) + unet_features_stack.append(output) + art_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + l+=1 + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, art_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_early_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_early_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d062f057a2080a471005125a00ee27453a4b0706 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_early_fusion.py @@ -0,0 +1,223 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +encapsulate the encoder and decoder for the Unet model. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mmUnet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + output = torch.cat((image, aux_image), 1) ### shape: (N, in_chans, H, W)`. + # print("image shape:", image.shape) + + stack, output = self.encoder(output) + output = self.conv(output) ### from [4, 256, 15, 15] to [4, 512, 15, 15] + output = self.decoder(output, stack) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mmUnet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_late_fusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_late_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b48e7d4445afd6a5bc1eaa2b065c37431585e94 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_late_fusion.py @@ -0,0 +1,229 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +use Unet for the reconstruction of each modality, concat the intermediate features of both modalities. 
+""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +from torch.nn import functional as F + + + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("decoder output:", output.shape) + + return output + + + + +class mmUnet_late(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + self.conv = ConvBlock(ch * 2, ch * 2, drop_prob) + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack1, output1 = self.encoder1(image) + stack2, output2 = self.encoder2(aux_image) + ### fuse two modalities to get multi-modal representation. + output = torch.cat([output1, output2], 1) + output = self.conv(output) ### from [4, 512, 15, 15] to [4, 512, 15, 15] + + output1 = self.decoder1(output, stack1) + output2 = self.decoder2(output, stack2) + + return output1, output2 + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mmUnet_late(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..88c7920d65c58ab82a00309a1ae501a8d5b33804 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_restormer.py @@ -0,0 +1,237 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + 
nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/bottom if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("upsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`.
+ """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_restormer_3blocks.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_restormer_3blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e3afe68612727d3195872a121ec8040fb0f66ef2 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_restormer_3blocks.py @@ -0,0 +1,235 @@ +""" +2023/07/18, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +之前 3 blocks的模型深层特征存在严重的gradient vanishing问题。现在把模型的encoder和decoder都改成只有3个blocks, 看是否还存在gradient vanishing问题。 +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 3, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2] + heads=[1, 2, 4] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \ + bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("encoder features:", feature.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # print("Unet feature range:", output.max(), output.min()) + # print("Restormer feature range:", feature.max(), feature.min()) + + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer feature:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_restormer_v2.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_restormer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b51d3a218057206863230973d557173e891d3cd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mmunet_restormer_v2.py @@ -0,0 +1,220 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +这个版本中是将Unet encoder提取的特征直接连接到decoder, 经过restormer refine的特征只是传递到下一层的encoder block. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_ART(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias') for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor, aux_image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + # output = image + output = torch.cat((image, aux_image), 1) + # print("image shape:", image.shape) + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + output = feature + + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return Unet_ART(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mtrans_mca.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mtrans_mca.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4c89df4af71e986fa7c24d522e6732371f3128 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mtrans_mca.py @@ -0,0 +1,119 @@ +import torch +from torch import nn + +from .mtrans_transformer import Mlp + +# implement by YunluYan + + +class LinearProject(nn.Module): + def __init__(self, dim_in, dim_out, fn): + super(LinearProject, self).__init__() + need_projection = dim_in != dim_out + self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity() + self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity() + self.fn = fn + + def forward(self, x, *args, **kwargs): + x = self.project_in(x) + x = self.fn(x, *args, **kwargs) + x = self.project_out(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): + super(MultiHeadCrossAttention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim*2, bias=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x, complement): + + # x [B, HW, C] + B_x, N_x, C_x = x.shape + + x_copy = x + + complement = torch.cat([x, complement], 1) + + B_c, N_c, C_c = complement.shape + + # q [B, heads, HW, C//num_heads] + q = self.to_q(x).reshape(B_x, N_x, self.num_heads, C_x//self.num_heads).permute(0, 2, 1, 3) + kv = self.to_kv(complement).reshape(B_c, N_c, 2, self.num_heads, C_c//self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_x, N_x, C_x) + + x = x + x_copy + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio = 1., attn_drop=0., proj_drop=0.,drop_path = 0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super(CrossTransformerEncoderLayer, self).__init__() + self.x_norm1 = norm_layer(dim) + self.c_norm1 = norm_layer(dim) + + self.attn = MultiHeadCrossAttention(dim, num_heads, attn_drop, proj_drop) + + self.x_norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) + + self.drop1 = nn.Dropout(drop_path) + self.drop2 = nn.Dropout(drop_path) + + def forward(self, x, complement): + x = self.x_norm1(x) + complement = self.c_norm1(complement) + + x = x + self.drop1(self.attn(x, complement)) + x = x + self.drop2(self.mlp(self.x_norm2(x))) + return x + + + +class CrossTransformer(nn.Module): + def __init__(self, x_dim, c_dim, depth, num_heads, mlp_ratio =1., attn_drop=0., proj_drop=0., drop_path =0.): + super(CrossTransformer, self).__init__() + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LinearProject(x_dim, c_dim, CrossTransformerEncoderLayer(c_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)), + LinearProject(c_dim, x_dim, 
CrossTransformerEncoderLayer(x_dim, num_heads, mlp_ratio, attn_drop, proj_drop, drop_path)) + ])) + + def forward(self, x, complement): + for x_attn_complemnt, complement_attn_x in self.layers: + x = x_attn_complemnt(x, complement=complement) + x + complement = complement_attn_x(complement, complement=x) + complement + return x, complement + + + + + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mtrans_net.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mtrans_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e80dadb45725603c7ff1da664a45533575db25 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mtrans_net.py @@ -0,0 +1,145 @@ +""" +This script implement the paper "Multi-Modal Transformer for Accelerated MR Imaging (TMI 2022)" +""" + + +import torch + +from torch import nn +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler +from .mtrans_mca import CrossTransformer + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(INPUT_DIM, HEAD_HIDDEN_DIM): + return ReconstructionHead(INPUT_DIM, HEAD_HIDDEN_DIM) + + +class CrossCMMT(nn.Module): + + def __init__(self, args): + super(CrossCMMT, self).__init__() + + INPUT_SIZE = 240 + INPUT_DIM = 1 # the channel of input + OUTPUT_DIM = 1 # the channel of output + HEAD_HIDDEN_DIM = 16 # the hidden dim of Head + TRANSFORMER_DEPTH = 4 # the depth of the transformer + TRANSFORMER_NUM_HEADS = 4 # the head's num of multi head attention + TRANSFORMER_MLP_RATIO = 3 # the MLP RATIO Of transformer + TRANSFORMER_EMBED_DIM = 256 # the EMBED DIM of transformer + P1 = 8 + P2 = 16 + CTDEPTH = 4 + + + self.head = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + self.head2 = build_head(INPUT_DIM, HEAD_HIDDEN_DIM) + + x_patch_dim = HEAD_HIDDEN_DIM * P1 ** 2 + x_num_patches = (INPUT_SIZE // P1) ** 2 + + complement_patch_dim = HEAD_HIDDEN_DIM * P2 ** 2 + complement_num_patches = (INPUT_SIZE // P2) ** 2 + + self.x_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P1, + p2=P1), + ) + + self.complement_patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=P2, + p2=P2), + ) + + self.x_pos_embedding = nn.Parameter(torch.randn(1, x_num_patches, x_patch_dim)) + self.complement_pos_embedding = nn.Parameter(torch.randn(1, complement_num_patches, complement_patch_dim)) + + ### + self.cross_transformer = CrossTransformer(x_patch_dim, complement_patch_dim, CTDEPTH, + TRANSFORMER_NUM_HEADS, TRANSFORMER_MLP_RATIO) + + self.p1 = P1 + self.p2 = P2 + + self.tail1 = nn.Conv2d(HEAD_HIDDEN_DIM, 
OUTPUT_DIM, 1) + self.tail2 = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x, complement): + x = self.head(x) + complement = self.head2(complement) + + x = x + complement + + b, _, h, w = x.shape + + _, _, c_h, c_w = complement.shape + + ### 得到每个模态的patch embedding (加上了position encoding) + x = self.x_patch_embbeding(x) + x += self.x_pos_embedding + + complement = self.complement_patch_embbeding(complement) + complement += self.complement_pos_embedding + + ### 将两个模态的patch embeddings送到cross transformer中提取特征。 + x, complement = self.cross_transformer(x, complement) + + c = int(x.shape[2] / (self.p1 * self.p1)) + H = int(h / self.p1) + W = int(w / self.p1) + + x = x.reshape(b, H, W, self.p1, self.p1, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w, ) + x = self.tail1(x) + + complement = complement.reshape(b, int(c_h/self.p2), int(c_w/self.p2), self.p2, self.p2, int(complement.shape[2]/self.p2/self.p2)) + complement = complement.permute(0, 5, 1, 3, 2, 4) + complement = complement.reshape(b, -1, c_h, c_w) + + complement = self.tail2(complement) + + return x, complement + + +def build_model(args): + return CrossCMMT(args) + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mtrans_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mtrans_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..61314a6a81247b0740a1e98d078242b26ce73f90 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/mtrans_transformer.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 
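+ # Decoder-style cross-attention: a small set of query tokens attends over the HW image tokens; forward(q, x) returns both the (pre-softmax) attention map and the attended, projected features.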
+ super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes 
standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(args, embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=args.MODEL.TRANSFORMER_DEPTH, + num_heads=args.MODEL.TRANSFORMER_NUM_HEADS, + mlp_ratio=args.MODEL.TRANSFORMER_MLP_RATIO) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_concat_decomp.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_concat_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..07c6f64804dd586927814801b167163f0b65f58f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_concat_decomp.py @@ -0,0 +1,295 @@ +""" +Xiaohan Xing, 2023/06/25 +两个模态分别用CNN提取多个层级的特征, 每个层级的特征都经过concat之后得到fused feature, +然后每个模态都通过一层conv变换得到2C*h*w的特征, 其中C个channels作为common feature, 另C个channels作为specific features. +利用CDDFuse论文中的decomposition loss约束特征解耦。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 
两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.transform_layers1 = nn.ModuleList() + self.transform_layers2 = nn.ModuleList() + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.transform_layers1.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.transform_layers2.append(ConvBlock(chans*(2**l), chans*(2**l)*2, drop_prob)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**l)*3, chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + stack1_common, stack1_specific, stack2_common, stack2_specific = [], [], [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_transform_layer, net2_transform_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.transform_layers1, self.transform_layers2, \ + self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
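+            # English summary of the notes above, for this decomposition-fusion step:
+            # each modality's ConvBlock feature (C channels) is projected to 2C channels,
+            # split into C "common" + C "specific" channels, and re-fused from
+            # [own common, own specific, other modality's common] (3C) back to C channels.
+            # Illustrative channel counts with the default chans=32 at level l=0:
+            #   transform: 32ch -> 64ch -> common 32ch + specific 32ch;
+            #   fuse conv: 96ch concat -> 32ch, which is pooled and fed to the next level.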
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + # print("original feature:", output1.shape) + + ### 分别提取两个模态的common feature和specific feature + output1 = net1_transform_layer(output1) + # print("transformed feature:", output1.shape) + output1_common = output1[:, :(output1.shape[1]//2), :, :] + output1_specific = output1[:, (output1.shape[1]//2):, :, :] + output2 = net2_transform_layer(output2) + output2_common = output2[:, :(output2.shape[1]//2), :, :] + output2_specific = output2[:, (output2.shape[1]//2):, :, :] + + stack1_common.append(output1_common) + stack1_specific.append(output1_specific) + stack2_common.append(output2_common) + stack2_specific.append(output2_specific) + + ### 将一个模态的两组feature和另一模态的common feature合并, 然后经过conv变换得到和初始特征相同维度的fused feature送到下一层。 + output1 = net1_fuse_layer(torch.cat((output1_common, output1_specific, output2_common), 1)) + output2 = net2_fuse_layer(torch.cat((output2_common, output2_specific, output1_common), 1)) + # print("fused feature:", output1.shape) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2, stack1_common, stack1_specific, stack2_common, stack2_specific + #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_multi_concat.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_multi_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90373e1134210809b98fdc64e3d51d76123855 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_multi_concat.py @@ -0,0 +1,306 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 
# padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 2, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+1)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor, \ + image1_krecon: torch.Tensor, image2_krecon: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + # output1, output2 = image1, image2 + output1 = torch.cat((image1, image1_krecon), 1) + output2 = torch.cat((image2, image2_krecon), 1) + + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, net1_fuse_layer, net2_fuse_layer in zip( + self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. 
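+            # English summary of the notes above: at each level the per-modality ConvBlock
+            # feature (C_l = chans * 2**l channels) is stored for the decoder skip connection,
+            # average-pooled, then the two pooled features are concatenated (2*C_l) and reduced
+            # back to C_l by a per-modality fusion ConvBlock. Note that the second fusion call
+            # below receives the already-fused output1 rather than the raw pooled feature.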
+ output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # relation1 = get_relation_matrix(output1) + # relation2 = get_relation_matrix(output2) + # relation_stack1.append(relation1) + # relation_stack2.append(relation2) + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # print("range of output1:", output1.max(), output1.min()) + # print("range of output2:", output2.max(), output2.min()) + + + # ### 将encoder中multi-modal fusion之前的特征通过skip connection连接到decoder. + # output1 = net1_layer(output1) + # output2 = net2_layer(output2) + + # ### 两个模态的特征concat + # fused_output1 = net1_fuse_layer(torch.cat((output1, output2), 1)) + # fused_output2 = net2_fuse_layer(torch.cat((output1, output2), 1)) + + # stack1.append(fused_output1) + # stack2.append(fused_output2) + + # ### downsampling features + # output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + # output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, relation_stack1, relation_stack2 #, t1_features, t2_features + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. + """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_multi_sum.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_multi_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4efb13b54ccbeaeb34fc51d3e72b0c50ee242 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_multi_sum.py @@ -0,0 +1,270 @@ +""" +Xiaohan Xing, 2023/06/19 +两个模态分别用CNN提取多个层级的特征, 每个层级都通过sum融合多模态特征。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding 
bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_multi_fuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过求平均的方式融合各层特征。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + t1_features, t2_features = [], [] + relation_stack1, relation_stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers): + + output1 = net1_layer(output1) + stack1.append(output1) + t1_features.append(output1) + + output2 = net2_layer(output2) + stack2.append(output2) + t2_features.append(output2) + + ## 两个模态的特征求平均. + fused_output1 = (output1 + output2)/2.0 + fused_output2 = (output1 + output2)/2.0 + + output1 = F.avg_pool2d(fused_output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(fused_output2, kernel_size=2, stride=2, padding=0) + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 #, t1_features, t2_features #, relation_stack1, relation_stack2 + + + +def get_relation_matrix(feature): + """ + 将各层的特征都变换成5*5的尺寸, 然后计算25*25个位置之间的relation matrix. 
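+    (Each feature map is block-reshaped to a 15x15 grid, adaptively pooled to 5x5 anchor
+    positions, and the pairwise cosine similarity between the 25 positions is returned
+    as a (bs, 25, 25) relation matrix.)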
+ """ + bs, c, h, w = feature.shape + feature = feature.view(bs, c, h//15, 15, w//15, 15).permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, -1, 15, 15) + avg_pool = nn.AdaptiveAvgPool2d(5) + feature = avg_pool(feature).view(bs, feature.shape[1], -1).permute(0, 2, 1) ### (bs, 5*5, c) + # print("intermediate feature:", feature.shape) + + feature_norm = torch.norm(feature, p=2, dim=-1, keepdim=True) ### (bs, 5*5, 1) + relation_matrix = torch.bmm(feature, feature.permute(0, 2, 1))/torch.bmm(feature_norm, feature_norm.permute(0, 2, 1)) # (bs, 25, 25) + + return relation_matrix + + + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_multi_fuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_multi_transfuse.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_multi_transfuse.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ce250a078ba0847c729c1ef1b76452f64d0f07 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_multi_transfuse.py @@ -0,0 +1,279 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +2023/06/23 +每个层级的特征用TransFuse layer融合之后, 将两个模态的original features和transfuse features concat, +然后在每个模态中经过conv层变换得到下一层的输入。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .TransFuse import TransFuse_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_TransFuse(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. 
+ """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + ### conv layer to transform the features after Transfuse layer. + self.fuse_conv_layers1 = nn.ModuleList() + self.fuse_conv_layers2 = nn.ModuleList() + for l in range(self.num_pool_layers): + # print("input and output channels:", chans*(2**(l+1)), chans*(2**l)) + self.fuse_conv_layers1.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + self.fuse_conv_layers2.append(ConvBlock(chans*(2**(l+2)), chans*(2**l), drop_prob)) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # transformer fusion modules + self.n_anchors = 15 + self.avgpool = nn.AdaptiveAvgPool2d((self.n_anchors, self.n_anchors)) + self.transfuse_layer1 = TransFuse_layer(n_embd=self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer2 = TransFuse_layer(n_embd=2*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer3 = TransFuse_layer(n_embd=4*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + self.transfuse_layer4 = TransFuse_layer(n_embd=8*self.chans, n_head=4, block_exp=4, n_layer=8, num_anchors=self.n_anchors) + + self.transfuse_layers = nn.Sequential(self.transfuse_layer1, self.transfuse_layer2, self.transfuse_layer3, self.transfuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, transfuse_layer, net1_fuse_layer, net2_fuse_layer in zip(self.encoder1.down_sample_layers, \ + self.encoder2.down_sample_layers, self.transfuse_layers, self.fuse_conv_layers1, self.fuse_conv_layers2): + + ### extract multi-level features from the encoder of each modality. 
+ output1 = net1_layer(output1) + output2 = net2_layer(output2) + + ### Transformer-based multi-modal fusion layer + feature1 = self.avgpool(output1) + feature2 = self.avgpool(output2) + feature1, feature2 = transfuse_layer(feature1, feature2) + feature1 = F.interpolate(feature1, scale_factor=output1.shape[-1]//self.n_anchors, mode='bilinear') + feature2 = F.interpolate(feature2, scale_factor=output2.shape[-1]//self.n_anchors, mode='bilinear') + + ### 两个模态的特征concat + output1 = net1_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + output2 = net2_fuse_layer(torch.cat((output1, output2, feature1, feature2), 1)) + + stack1.append(output1) + stack2.append(output2) + + ### downsampling features + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_TransFuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_swinfusion.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_swinfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6203c4331744f1da70f18e403360196a419b4cb5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/munet_swinfusion.py @@ -0,0 +1,268 @@ +""" +Xiaohan Xing, 2023/06/12 +两个模态分别用CNN提取多个层级的特征, 每个层级都用TransFuse layer融合。融合后的特征和encoder原始的特征相加。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F +from .SwinFuse_layer import SwinFusion_layer + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # print("decoder layer output:", output.shape) + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + # print("output after transpose conv:", output.shape) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class mUnet_SwinFuse(nn.Module): + """ + 整体框架是multi-modal Unet. 两个模态分别提取各层特征,然后通过SwinFusion的方式融合。 + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. 
+ drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.fuse_type = 'swinfuse' + self.img_size = 240 + # self.fuse_type = 'sum' + + self.encoder1 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + self.encoder2 = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder1.ch + + # print("encoder layers:", self.encoder1.down_sample_layers) + + self.conv1 = ConvBlock(ch, ch * 2, drop_prob) + self.conv2 = ConvBlock(ch, ch * 2, drop_prob) + + self.decoder1 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + self.decoder2 = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + # Swin transformer fusion modules + + self.fuse_layer1 = SwinFusion_layer(img_size=(self.img_size//2, self.img_size//2), patch_size=4, window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer2 = SwinFusion_layer(img_size=(self.img_size//4, self.img_size//4), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=2*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer3 = SwinFusion_layer(img_size=(self.img_size//8, self.img_size//8), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=4*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.fuse_layer4 = SwinFusion_layer(img_size=(self.img_size//16, self.img_size//16), window_size=5, Fusion_depths=[2, 2, 2, 2], \ + embed_dim=8*self.chans, Fusion_num_heads=[4, 4, 4, 4]) + self.swinfusion_layers = nn.Sequential(self.fuse_layer1, self.fuse_layer2, self.fuse_layer3, self.fuse_layer4) + + # print("length of encoder layers:", len(self.encoder1.down_sample_layers), len(self.encoder1.down_sample_layers), len(self.transfuse_layers)) + + + + def forward(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # output = self.encoder1.(output) + # output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output1, output2 = image1, image2 + stack1, stack2 = [], [] + + ### 两个模态分别用CNN layer提取特征,然后用transfuse layer融合,将融合前后的特征相加送到后续的layers. + for net1_layer, net2_layer, swinfuse_layer in zip(self.encoder1.down_sample_layers, self.encoder2.down_sample_layers, self.swinfusion_layers): + output1 = net1_layer(output1) + stack1.append(output1) + output1 = F.avg_pool2d(output1, kernel_size=2, stride=2, padding=0) + + output2 = net2_layer(output2) + stack2.append(output2) + output2 = F.avg_pool2d(output2, kernel_size=2, stride=2, padding=0) + + # print("output of encoders:", output1.shape, output2.shape) + # print("CNN feature range:", output1.max(), output2.min()) + + ### Swin transformer-based multi-modal fusion. + feature1, feature2 = swinfuse_layer(output1, output2) + output1 = (output1 + feature1 + output2 + feature2)/4.0 + output2 = (output1 + feature1 + output2 + feature2)/4.0 + + # output = torch.cat((output1, output2), 1) + output1 = self.conv1(output1) + output2 = self.conv2(output2) + output1 = self.decoder1(output1, stack1) + output2 = self.decoder2(output2, stack2) + + return output1, output2 + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. 
+ """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. 
+ """ + return self.layers(image) + + + +def build_model(args): + return mUnet_SwinFuse(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/original_MINet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/original_MINet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a115872320f677d8b34264eed95c22fff0df9c4 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/original_MINet.py @@ -0,0 +1,335 @@ +from fastmri.models import common +import torch +import torch.nn as nn +import torch.nn.functional as F + +def make_model(args, parent=False): + return SR_Branch(args) + +## Channel Attention (CA) Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class LAM_Module(nn.Module): + """ Layer attention module""" + def __init__(self, in_dim): + super(LAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, N, C, height, width = x.size() + proj_query = x.view(m_batchsize, N, -1) + proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, N, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, N, C, height, width) + + out = self.gamma*out + x + out = out.view(m_batchsize, -1, height, width) + return out + +class CSAM_Module(nn.Module): + """ Channel-Spatial attention module""" + def __init__(self, in_dim): + super(CSAM_Module, self).__init__() + self.chanel_in = in_dim + + + self.conv = nn.Conv3d(1, 1, 3, 1, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + #self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + def forward(self,x): + """ + inputs : + x : input feature maps( B X N X C X H X W) + returns : + out : attention value + input feature + attention: B X N X N + """ + m_batchsize, C, height, width = x.size() + out = x.unsqueeze(1) + out = self.sigmoid(self.conv(out)) + + out = self.gamma*out + out = out.view(m_batchsize, -1, height, width) + x = x * out + x + return x + +## Residual Channel Attention Block (RCAB) +class RCAB(nn.Module): + def __init__( + self, conv, n_feat, kernel_size, reduction, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(RCAB, self).__init__() + modules_body = [] + for i in range(2): + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + if bn: modules_body.append(nn.BatchNorm2d(n_feat)) + if i == 0: modules_body.append(act) + modules_body.append(CALayer(n_feat, reduction)) + self.body = nn.Sequential(*modules_body) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x) + #res = self.body(x).mul(self.res_scale) + res += x + 
return res + +## Residual Group (RG) +class ResidualGroup(nn.Module): + def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [] + modules_body = [ + RCAB( + conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ + for _ in range(n_resblocks)] + modules_body.append(conv(n_feat, n_feat, kernel_size)) + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) + res += x + return res + + +class SR_Branch(nn.Module): + def __init__(self,n_resgroups,n_resblocks,n_feats,conv=common.default_conv): + super(SR_Branch, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + kernel_size = 3 + reduction = 16 + + scale = 2 + rgb_range = 255 + n_colors = 1 + res_scale =0.1 + act = nn.ReLU(True) + + # define head module + modules_head = [conv(n_colors, n_feats, kernel_size)] + + # define body module + modules_body = [ + ResidualGroup( + conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ + for _ in range(n_resgroups)] + + modules_body.append(conv(n_feats, n_feats, kernel_size)) + + # define tail module + modules_tail = [ + common.Upsampler(conv, scale, n_feats, act=False), + conv(n_feats,n_feats, kernel_size)]#n_colors + + self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1) + + self.head = nn.Sequential(*modules_head) + self.body = nn.Sequential(*modules_body) + self.csa = CSAM_Module(n_feats) + self.la = LAM_Module(n_feats) + self.last_conv = nn.Conv2d(n_feats*(n_resgroups+1), n_feats, 3, 1, 1) + self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1) + self.last1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1) + self.tail = nn.Sequential(*modules_tail) + self.final = nn.Conv2d(n_feats, n_colors, 3, 1, 1) + + def forward(self, x): + outputs = [] + x = self.head(x) + outputs.append(x) + + res = x + + for name, midlayer in self.body._modules.items(): + res = midlayer(res) + + if name=='0': + res1 = res.unsqueeze(1) + else: + res1 = torch.cat([res.unsqueeze(1),res1],1) + + outputs.append(res1) + + out1 = res + res = self.la(res1) + out2 = self.last_conv(res) + out1 = self.csa(out1) + out = torch.cat([out1, out2], 1) + res = self.last(out) + + res += x + + outputs.append(res) + + x = self.tail(res) + + return outputs, x + + def load_state_dict(self, state_dict, strict=False): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name in own_state: + if isinstance(param, nn.Parameter): + param = param.data + try: + own_state[name].copy_(param) + except Exception: + if name.find('tail') >= 0: + print('Replace pre-trained upsampler to new one...') + else: + raise RuntimeError('While copying the parameter named {}, ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}.' 
+ .format(name, own_state[name].size(), param.size())) + elif strict: + if name.find('tail') == -1: + raise KeyError('unexpected key "{}" in state_dict' + .format(name)) + + if strict: + missing = set(own_state.keys()) - set(state_dict.keys()) + if len(missing) > 0: + raise KeyError('missing keys in state_dict: "{}"'.format(missing)) +class Pred_Layer(nn.Module): + def __init__(self, in_c=32): + super(Pred_Layer, self).__init__() + self.enlayer = nn.Sequential( + nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.outlayer = nn.Sequential( + nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0), ) + + def forward(self, x): + x = self.enlayer(x) + x = self.outlayer(x) + return x + +class MINet(nn.Module): + def __init__(self, n_resgroups,n_resblocks, n_feats): + super(MINet, self).__init__() + + self.n_resgroups = n_resgroups + self.n_resblocks = n_resblocks + self.n_feats = n_feats + self.net1 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + self.net2 = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + main_net = SR_Branch( + n_resgroups = self.n_resgroups, + n_resblocks = self.n_resblocks, + n_feats = self.n_feats, + ) + + + self.body = main_net.body + self.csa = main_net.csa + self.la = main_net.la + self.last_conv = main_net.last_conv + self.last = main_net.last + self.last1 = main_net.last1 + self.tail = main_net.tail + self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1) + nlayer = len(self.net1.body._modules.items()) + self.fusion_convs = nn.ModuleList([nn.Conv2d(128, 64, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT1 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.fusion_convsT2 = nn.ModuleList([nn.Conv2d(64, 32, kernel_size=1, padding=0) for i in range(nlayer)]) + self.map_convs = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=3, padding=1) for i in range(nlayer)]) + self.rgbd_global = Pred_Layer(32 * 2) + + def forward(self, x1, x2): + + x1 = self.net1.head(x1) + x2 = self.net2.head(x2) + + x2 = self.tail(x2) + + resT1 = x1 + resT2 = x2 + + t1s = [] + t2s = [] + + for m1, m2,fusion_conv in zip(self.net1.body._modules.items(),self.net2.body._modules.items(), self.fusion_convs): + name1, midlayer1 = m1 + _, midlayer2 = m2 + + resT1 = midlayer1(resT1) + resT2 = midlayer2(resT2) + + t1s.append(resT1.unsqueeze(1)) + t2s.append(resT2.unsqueeze(1)) + + res = torch.cat([resT1,resT2],dim=1) + res = fusion_conv(res) + + resT2 = res+resT2 + + out1T1 = resT1 + out1T2 = resT2 + + ts = t1s + t2s + ts = torch.cat(ts,dim=1) + res1_T2 = self.net2.la(ts) + out2_T2 = self.net2.last_conv(res1_T2) + + out1T1 = self.net1.csa(out1T1) + out1T2 = self.net2.csa(out1T2) + + outT2 = torch.cat([out1T2, out2_T2], 1) + resT1 = self.net1.last1(out1T1) + resT2 = self.net2.last(outT2) + + resT1 += x1 + resT2 += x2 + + x1 = self.net1.final(resT1) + x2 = self.net2.final(resT2) + return x1,x2#x1=pd x2=pdfs + \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..3699cc809e438be4e1bd5efaba7d96e18d7f1602 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/restormer.py @@ -0,0 +1,307 @@ +## Restormer: 
Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + # print("feature before FFN projection:", x.max(), x.min()) + x = self.project_out(x) + # print("FFN output:", x.max(), x.min()) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, 
x): + b,c,h,w = x.shape + # print("Attention block input:", x.max(), x.min()) + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + # print("attention values:", attn) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("Attention block output:", out.max(), out.min()) + + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + # dim = 32, + # num_blocks = [4,4,4,4], + dim = 32, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + # self.tanh = nn.Tanh() + + def forward(self, inp_img): + + # stack = [] + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + # print("encoder block1 feature:", out_enc_level1.shape, out_enc_level1.max(), out_enc_level1.min()) + # stack.append(out_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + # print("encoder block2 feature:", out_enc_level2.shape, out_enc_level2.max(), out_enc_level2.min()) + # stack.append(out_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + # print("encoder block3 feature:", out_enc_level3.shape, out_enc_level3.max(), out_enc_level3.min()) + # 
stack.append(out_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + # print("encoder block4 feature:", latent.shape, latent.max(), latent.min()) + # stack.append(latent) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + # out_dec_level1 = self.tanh(out_dec_level1) + + + return out_dec_level1 #, stack + + + +def build_model(args): + return Restormer() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/restormer_block.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/restormer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..170ae6a826e5a3e356cb48f8c89500c2d5242816 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/restormer_block.py @@ -0,0 +1,321 @@ +## Restormer: Efficient Transformer for High-Resolution Image Restoration +## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang +## https://arxiv.org/abs/2111.09881 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pdb import set_trace as stx +import numbers + +from einops import rearrange + + + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + +class 
LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + # print("input to the LayerNorm layer:", x.max(), x.min()) + h, w = x.shape[-2:] + x = to_4d(self.body(to_3d(x)), h, w) + # print("output of the LayerNorm:", x.max(), x.min()) + return x + + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim*ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + # self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature = 1.0 / ((dim//num_heads) ** 0.5) + + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + + def forward(self, x): + # print("input to the Attention layer:", x.max(), x.min()) + + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + # print("q range:", q.max(), q.min()) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + # print("scaling parameter:", self.temperature) + # print("attention matrix before softmax:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + attn = attn.softmax(dim=-1) + # print("attention matrix range:", attn[0,0,:,:].max(), attn[0,0,:,:].min()) + # print("v range:", v.max(), v.min(), v.mean()) + + out = (attn @ v) + # print("self-attention output:", out.max(), out.min()) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + # print("output of attention layer:", out.max(), out.min()) + return out + + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, pre_norm=True): # pre_norm defaults to the original pre-norm ordering; the Restormer class below constructs blocks without passing it + super(TransformerBlock, self).__init__() + + self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + self.pre_norm = pre_norm + + def forward(self, 
x): + x = self.conv(x) + # print("restormer block input:", x.max(), x.min()) + if self.pre_norm: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + else: + x = self.norm1(x + self.attn(x)) + x = self.norm2(x + self.ffn(x)) + # x = self.norm1(self.attn(x)) + # x = self.norm2(self.ffn(x)) + # x = x + self.attn(x) + # x = x + self.ffn(x) + # print("restormer block output:", x.max(), x.min()) + + return x + + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) + + +########################################################################## +##---------- Restormer ----------------------- +class Restormer(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim = 48, + num_blocks = [4,6,6,8], + num_refinement_blocks = 4, + heads = [1,2,4,8], + ffn_expansion_factor = 2.66, + bias = False, + LayerNorm_type = 'WithBias', ## Other option 'BiasFree' + dual_pixel_task = False ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 + ): + + super(Restormer, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) + + self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) + + + self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + 
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + + + return out_dec_level1 + + + + +if __name__ == '__main__': + model = Restormer() + # print(model) + + A = torch.randn((1, 1, 240, 240)) + x = model(A) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/swinIR.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/swinIR.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfc9c175c02531ac80a05ba4abfc0776f752dd --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/swinIR.py @@ -0,0 +1,874 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * 
(self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 4, 4, 6], + embed_dim=32, num_heads=[6, 6, 6, 6], mlp_ratio=4, upsampler='pixelshuffledirect') + + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/swin_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..877bdfbffc91012e4ac31f41b838ebb116f4a825 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/swin_transformer.py @@ -0,0 +1,878 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. 
+# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=1, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=1, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
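+            # the embed_dim -> embed_dim // 4 -> embed_dim bottleneck below roughly halves the parameters of a single embed_dim x embed_dim 3x3 conv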
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + # print("shallow feature:", x.shape) + x = self.conv_after_body(self.forward_features(x)) + x + # print("deep feature:", x.shape) + x = self.upsample(x) + # print("after upsample:", x.shape) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x 
= self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +def build_model(args): + # return SwinIR(upscale=1, img_size=(240, 240), + # window_size=8, img_range=1., depths=[6, 6, 6, 6], + # embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + return SwinIR(upscale=1, img_size=(240, 240), + window_size=8, img_range=1., depths=[2, 2, 2], + embed_dim=60, num_heads=[6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + +if __name__ == '__main__': + height = 240 + width = 240 + window_size = 8 + model = SwinIR(upscale=1, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + + x = torch.randn((1, 1, height, width)) + print("input:", x.shape) + x = model(x) + print("output:", x.shape) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet.zip b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet.zip new file mode 100644 index 0000000000000000000000000000000000000000..e7b71c1287551fd6119efd7c62da2196307998f5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:369de2a28036532c446409ee1f7b953d5763641ffd452675613e1bb9bba89eae +size 19354 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet/trans_unet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet/trans_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa61dbadfcaee9b26594307adc90cdd1d2a84e3e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet/trans_unet.py @@ -0,0 +1,489 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging +import math + +from os.path import join as pjoin + +import torch +import torch.nn as nn +import numpy as np + +from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair +from scipy import ndimage +from . 
import vit_seg_configs as configs +# import .vit_seg_configs as configs +from .vit_seg_modeling_resnet_skip import ResNetV2 + + +logger = logging.getLogger(__name__) + + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class 
Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + # print("image size:", img_size, "grid size:", grid_size) + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + # print("patch_size_real", patch_size_real) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + # print("input x:", x.shape) ### (4, 3, 240, 240) + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + # print("output of the hybrid model:", x.shape) ### (4, 1024, 15, 15) + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = 
np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + # print(self.encoder) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) ### "features" store the features from different layers. + # print("embedding_output:", embedding_output.shape) ## (B, n_patch, hidden) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + # print("encoded feature:", encoded.shape) + return encoded, attn_weights, features + + # return embedding_output, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class ReconstructionHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + 
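+        # conv_more projects the transformer hidden size down to 512 decoder channels with a 3x3 conv before the cascaded upsampling blocks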
self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + # print(self.blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): + super(VisionTransformer, self).__init__() + self.num_classes = num_classes + self.zero_head = zero_head + self.classifier = config.classifier + self.transformer = Transformer(config, img_size, vis) + self.decoder = DecoderCup(config) + self.segmentation_head = ReconstructionHead( + in_channels=config['decoder_channels'][-1], + out_channels=config['n_classes'], + kernel_size=3, + ) + self.activation = nn.Tanh() + self.config = config + + def forward(self, x): + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + ### hybrid CNN and transformer to extract features from the input images. + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + + # ### extract features with CNN only. 
+ # x, features = self.transformer(x) # (B, n_patch, hidden) + + x = self.decoder(x, features) + # logits = self.segmentation_head(x) + logits = self.activation(self.segmentation_head(x)) + # print("logits:", logits.shape) + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + + + +def build_model(args): + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + print(net) + # net.load_from(weights=np.load(config_vit.pretrained_path)) + return net + + +if __name__ == "__main__": + config_vit = CONFIGS['R50-ViT-B_16'] + net = VisionTransformer(config_vit, img_size=240, num_classes=1).cuda() + net.load_from(weights=np.load(config_vit.pretrained_path)) + + input_data = torch.rand(4, 1, 240, 240).cuda() + print("input:", input_data.shape) + output = net(input_data) + print("output:", output.shape) diff --git 
a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet/vit_seg_configs.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet/vit_seg_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bed7d3346e30328a38ac6fef1621f788d905df1e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet/vit_seg_configs.py @@ -0,0 +1,132 @@ +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 4 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + # config.patches.grid = (16, 16) + config.patches.grid = (15, 15) ### for the input image with size (240, 240) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'recon' + config.pretrained_path = '/home/xiaohan/workspace/MSL_MRI/code/pretrained_model/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 1 + config.n_skip = 3 + config.activation = 'tanh' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. 
customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae80ef7d96c46cb91bda849901aef34008e62ed --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/trans_unet/vit_seg_modeling_resnet_skip.py @@ -0,0 +1,160 @@ +import math + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. 
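+            # the 1x1 conv matches the shortcut's channel count and stride; GroupNorm(cout, cout) puts each channel in its own group, i.e. per-channel normalization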
+ self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - 
x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/transformer_modules.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/transformer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..900042884305619e39ad9b792b3b1936dfbd37b3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/transformer_modules.py @@ -0,0 +1,252 @@ +import math +import warnings + +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + # two mlp, fc-relu-drop-fc-relu-drop + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention_Encoder(nn.Module): + def __init__(self, dim, kv_reduced_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if kv_reduced_dim is not None and type(kv_reduced_dim) == int: + self.fc_k = nn.Linear() + + def forward(self, x): + B, N, C = x.shape + # qkv shape [3, N, num_head, HW, C//num_head] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [N, num_head, HW, C//num_head] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Attention_Decoder(nn.Module): + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.fc_q = nn.Linear(dim, dim * 1, bias=qkv_bias) + self.fc_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, x): + # q:[B,12,256] x:[B,HW,256] + B, N, C = x.shape + n_class = q.shape[1] + + q = self.fc_q(q).reshape(B, self.num_heads, n_class, C // self.num_heads) + kv = self.fc_kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # [B, num_head, HW, 256/num_head] + + attn1 = (q @ k.transpose(-2, -1)) * self.scale # [B, num_head, 12, 
HW] + attn2 = attn1.softmax(dim=-1) + attn3 = self.attn_drop(attn2) # [B, num_head, 11, HW] + + x = (attn3 @ v).reshape(B, n_class, C) + x = self.proj(x) + x = self.proj_drop(x) # [B, 12, 256] + + # attn = attn3.permute(0, 2, 1, 3) + attn = attn1.permute(0, 2, 1, 3) + # attn = attn2.permute(0, 2, 1, 3) + return attn, x + + +class Block_Encoder(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_Encoder( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + def __init__(self, embed_dim=768, depth=4, + num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.transformer_encoder = nn.ModuleList([ + Block_Encoder( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_embed'} + + def forward(self, x): + + x = self.pos_drop(x) + for blk in self.transformer_encoder: + x = blk(x) + + x = self.norm(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
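+        # In erf space: for a standard normal, cdf(x) = (1 + erf(x / sqrt(2))) / 2,
+        # so a uniform draw in [2l-1, 2u-1] corresponds to erf values whose
+        # erfinv * sqrt(2) is a standard-normal sample restricted to
+        # [(a - mean) / std, (b - mean) / std]; the scaling and shifting below
+        # then yield the truncated N(mean, std^2) sample.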
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +def build_transformer(embed_dim): + return VisionTransformer(embed_dim=embed_dim, + depth=4, + num_heads=4, + mlp_ratio=3) + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dd935cd34a121d8079a8f5184d4ea29e15c8da --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unet.py @@ -0,0 +1,196 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +from torch.nn import functional as F + + +class Unet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. 
+ Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + feature_stack = [] + output = image + # print("image shape:", image.shape) + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + feature_stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + # print("unet feature range:", output.max(), output.min()) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, feature_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. 
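+
+        Example (illustrative sketch; ``args`` is not referenced by this class,
+        and a 240x240 input is used so every pooling step divides evenly):
+            >>> import torch
+            >>> model = Unet(args=None)
+            >>> out = model(torch.randn(1, 1, 240, 240))
+            >>> out.shape
+            torch.Size([1, 1, 240, 240])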
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unet_restormer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unet_restormer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f78aa5f192957b2706c6f691dfc66c427b65ae --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unet_restormer.py @@ -0,0 +1,234 @@ +""" +2023/07/12, +将两个模态的图像concat作为input image, 在Unet的每层特征后面都连接一个restormer block提取channel-wise long-range dependent feature. +""" + +import torch +from torch import nn +from torch.nn import functional as F +from .restormer_block import TransformerBlock + + +class Unet_Restormer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + self.img_size = 240 + + self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, self.drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + + num_blocks=[2, 2, 2, 2] + heads=[1, 2, 4, 8] + ffn_expansion_factor = 2.66 + + self.restor_block1 = nn.ModuleList([TransformerBlock(dim=chans, num_heads=heads[0], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[0])]) + self.restor_block2 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 1), num_heads=heads[1], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[1])]) + self.restor_block3 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 2), num_heads=heads[2], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[2])]) + self.restor_block4 = nn.ModuleList([TransformerBlock(dim=int(chans * 2 ** 3), num_heads=heads[3], \ + ffn_expansion_factor=ffn_expansion_factor, bias=False, LayerNorm_type='WithBias', pre_norm=False) for i in range(num_blocks[3])]) + self.restormer_blocks = nn.Sequential(self.restor_block1, self.restor_block2, self.restor_block3, self.restor_block4) + + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + # nn.Tanh(), + ) + ) + + # def forward(self, image: torch.Tensor) -> torch.Tensor: + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. 
+ """ + stack = [] + output = image + # print("image shape:", image.shape) + unet_features_stack, restormer_features_stack = [], [] + + # print("input image range:", output.max(), output.min()) + + # apply down-sampling layers + for layer, restor_block in zip(self.down_sample_layers, self.restormer_blocks): + output = layer(output) + # print("Unet feature range:", output.max(), output.min()) + + feature = output + for restormer_layer in restor_block: + feature = restormer_layer(feature) + # print("intermediate feature1:", feature1.shape) + + unet_features_stack.append(output) + restormer_features_stack.append(feature) + + # output = (output + feature)/2.0 + output = feature + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + # print("downsampling layer:", output.shape) + + output = self.conv(output) + # print("intermediate layer output:", output.shape) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + # print("unsampling layer:", output.shape) + + return output #, unet_features_stack, restormer_features_stack + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + # for layer in range(len(self.layers)): + # print(self.layers[layer]) + # image = self.layers[layer](image) + # # print("feature range of this layer:", image.max(), image.min()) + # return image + return self.layers(image) + + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + # nn.BatchNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Restormer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unet_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unet_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..982fe23cec321de6b636e04c51b62037054ea432 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unet_transformer.py @@ -0,0 +1,346 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Xiaohan Xing, 2023/05/23 +对于Unet中的high-level feature, 用Transformer进行特征交互,然后和Transformer之前的特征一起送到decoder重建图像。 +""" +# coding: utf-8 +from typing import Any +import torch +from torch import nn +import numpy as np +from torch.nn import functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +####################################################### +########## Transformer for feature enhancement ######## +####################################################### +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def get_feature(self, x): + for idx, (attn, ff) 
in enumerate(self.layers): + x = attn(x) + x + x = ff(x) + x + if idx == 0: + return x + # return x + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + + + +################ Feature Extractor ############## + +class Encoder(nn.Module): + def __init__(self, num_pool_layers, in_chans, chans, drop_prob): + super(Encoder, self).__init__() + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.ch = ch + + def forward(self, x): + stack = [] + output = x + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + return stack, output + + + + +class Decoder(nn.Module): + def __init__(self, num_pool_layers, ch, out_chans, drop_prob): + super(Decoder, self).__init__() + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, out_chans, kernel_size=1, stride=1), + nn.Tanh(), + ) + ) + + def forward(self, x, stack): + output = x + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + + + + +class Unet_Transformer(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234 241. + Springer, 2015. + """ + + def __init__( + self, args, + input_dim: int = 1, + output_dim: int = 1, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + # def __init__(self, args): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. 
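+
+        Example (illustrative sketch; ``args`` is not referenced by this class,
+        and a 240x240 input is assumed so the bottleneck map is 15x15, matching
+        ``fmp_size`` and the positional embedding of 15*15 patches plus one cls token):
+            >>> import torch
+            >>> model = Unet_Transformer(args=None)
+            >>> out = model(torch.randn(1, 1, 240, 240))
+            >>> out.shape
+            torch.Size([1, 1, 240, 240])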
+ """ + super().__init__() + + self.in_chans = input_dim + self.out_chans = output_dim + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.encoder = Encoder(self.num_pool_layers, self.in_chans, self.chans, self.drop_prob) + ch = self.encoder.ch + self.conv = ConvBlock(ch, ch * 2, drop_prob) + self.decoder = Decoder(self.num_pool_layers, ch, self.out_chans, self.drop_prob) + + + # transformer fusion modules + fmp_size = 15 # feature map after encoder [bs,encoder.output_dim,fmp_size,fmp_size]=[4,256,96,96] + num_patch = fmp_size * fmp_size + patch_dim = ch + + self.to_patch_embedding = Rearrange('b e (h) (w) -> b (h w) e') + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patch + 1, patch_dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, patch_dim)) + self.dropout = nn.Dropout(0.1) + + self.transformer = Transformer(patch_dim, depth=2, heads=8, dim_head=64, mlp_dim=3072, dropout=0.1) + + + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + output = image + # print("image shape:", image.shape) + + ### CNN encoder提取high-level feature + stack, output = self.encoder(output) ### output size: [4, 256, 15, 15] + + ### 用transformer进行feature enhancement, 然后和前面提取的特征concat. + m_embed = self.to_patch_embedding(output) # [4, 225, 256], 225 is the number of patches. + patch_embed_input = F.normalize(m_embed, p=2.0, dim=-1, eps=1e-12) + b, n, _ = m_embed.shape + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) + + patch_embed_input = torch.cat((cls_tokens, patch_embed_input), 1) + # print("patch embedding:", patch_embed_input.shape, "pos embedding:", self.pos_embedding.shape) + patch_embed_input += self.pos_embedding + patch_embed_input = self.dropout(patch_embed_input) # [4, 1152+1, 1024] + + feature_output = self.transformer(patch_embed_input)[:, 1:, :] # [4, 225, 256] + # print("output of the transformer:", feature_output.shape) + h, w = int(np.sqrt(feature_output.shape[1])), int(np.sqrt(feature_output.shape[1])) + feature_output = feature_output.contiguous().view(b, feature_output.shape[-1], h, w) # [4,512*2,24,24] + + output = torch.cat((output, feature_output), 1) + # print("feature:", output.shape) + + ### 用CNN和transformer联合的特征送到decoder进行图像重建 + output = self.decoder(output, stack) + + return output + + + +class ConvBlock(nn.Module): + """ + A Convolutional Block that consists of two convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. 
+ + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + return self.layers(image) + + +class TransposeConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose + layers followed by instance normalization and LeakyReLU activation. + """ + + def __init__(self, in_chans: int, out_chans: int): + """ + Args: + in_chans: Number of channels in the input. + out_chans: Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H*2, W*2)`. + """ + return self.layers(image) + + + +def build_model(args): + return Unet_Transformer(args) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unimodal_transformer.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unimodal_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c99214e10bee2f7a1be49f8560799ffbc14ed0a5 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/compare_models/unimodal_transformer.py @@ -0,0 +1,112 @@ +import torch + +from torch import nn +from .transformer_modules import build_transformer +from einops.layers.torch import Rearrange +from .MINet_common import default_conv as conv, Upsampler + + +# add by YunluYan + + +class ReconstructionHead(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(ReconstructionHead, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.conv1 = nn.Conv2d(self.input_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) + + self.bn1 = nn.BatchNorm2d(self.hidden_dim) + self.bn2 = nn.BatchNorm2d(self.hidden_dim) + self.bn3 = nn.BatchNorm2d(self.hidden_dim) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.act(out) + + return out + + +def build_head(): + return ReconstructionHead(input_dim=1, hidden_dim=12) + + + +class CMMT(nn.Module): + + def __init__(self, args): + super(CMMT, self).__init__() + + self.head = build_head() + + HEAD_HIDDEN_DIM = 12 + PATCH_SIZE = 16 + INPUT_SIZE = 240 + OUTPUT_DIM = 1 + + patch_dim = HEAD_HIDDEN_DIM* PATCH_SIZE ** 2 + num_patches = (INPUT_SIZE // PATCH_SIZE) ** 2 + + self.transformer = build_transformer(patch_dim) + + self.patch_embbeding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=PATCH_SIZE, p2=PATCH_SIZE), + ) + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, patch_dim)) + + + self.p1 = PATCH_SIZE + self.p2 = PATCH_SIZE + + self.tail = nn.Conv2d(HEAD_HIDDEN_DIM, OUTPUT_DIM, 1) + + + def forward(self, x): + + b,_ , h, w = x.shape + + x = self.head(x) + + x= self.patch_embbeding(x) + + x += self.pos_embedding + + x = self.transformer(x) # b HW p1p2c + + c = int(x.shape[2]/(self.p1*self.p2)) + H = int(h/self.p1) + W = 
int(w/self.p2) + + x = x.reshape(b, H, W, self.p1, self.p2, c) # b H W p1 p2 c + x = x.permute(0, 5, 1, 3, 2, 4) # b c H p1 W p2 + x = x.reshape(b, -1, h, w,) + x = self.tail(x) + + return x + + +def build_model(args): + return CMMT(args) + + + + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/blocks.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/blocks.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
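+# The implementation below normalizes every (sample, channel) slice over its
+# H*W positions with the mean and rsqrt(var + eps), then optionally applies a
+# learned per-channel scale/shift. Note the non-standard affine initialization
+# (scale ~ N(0, 0.02), shift = 0) compared with nn.InstanceNorm2d(affine=True),
+# which starts from weight = 1 and bias = 0.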
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/common_freq.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6c78467db1391f6475d069dafda7f295b97ae1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/common_freq.py @@ -0,0 +1,391 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + + +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + 
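+            # all remaining scales fall back to a 1x1 conv before each PixelShuffle(2) stage
+            # (the stages themselves are only built when scale is a power of two)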
kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs): + + out = self.layers(inputs) + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features): + super(ResBlock, self).__init__() + self.layers = nn.Sequential( + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1), + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + ) + + def forward(self, inputs): + return F.relu(self.layers(inputs) + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, 
out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [ + ResBlock(n_feat) for _ in range(n_resblocks)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x): + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + 
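+        # Cross-attention fusion: each branch is first projected by a 3x3 conv,
+        # then the two branches attend to each other (the forward pass below
+        # routes both directions through ``fre_att``, leaving ``spa_att`` unused),
+        # and a Sigmoid gate computed from their concatenation re-weights the
+        # branches before they are summed.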
self.fre_att = Attention(dim=channels) + self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 
* t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/model.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2fd1e66946b7305f6e2228c92dab23da2c545dfe --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/frequency_model/model.py @@ -0,0 +1,376 @@ +import torch +from torch import nn +from . import common_freq as common + + +class TwoBranch(nn.Module): + def __init__(self, num_features, act, base_num_every_group, num_channels): + super(TwoBranch, self).__init__() + + self.num_features = num_features + self.act = act + self.num_channels = num_channels + + num_group = 4 + num_every_group = base_num_every_group + + self.init_T2_frq_branch() + self.init_T2_spa_branch( num_every_group) + self.init_T2_fre_spa_fusion() + + self.init_T1_frq_branch() + self.init_T1_spa_branch( num_every_group) + + self.init_modality_fre_fusion() + self.init_modality_spa_fusion() + + + def init_T2_frq_branch(self, ): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_fre = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + + self.down1_fre = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_down2_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down2_fre = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_down3_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down3_fre = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_neck_fre = [common.FreBlock9(self.num_features, ) + ] + self.neck_fre = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo = nn.Sequential(common.FreBlock9(self.num_features )) + + modules_up1_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up1_fre = nn.Sequential(*modules_up1_fre) + self.up1_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_up2_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up2_fre = nn.Sequential(*modules_up2_fre) + self.up2_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_up3_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up3_fre = nn.Sequential(*modules_up3_fre) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + # define tail module + modules_tail_fre = [ + 
common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act)] + self.tail_fre = nn.Sequential(*modules_tail_fre) + + def init_T2_spa_branch(self, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down1 = nn.Sequential(*modules_down1) + + + self.down1_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down2 = nn.Sequential(*modules_down2) + + self.down2_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down3 = nn.Sequential(*modules_down3) + self.down3_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.neck = nn.Sequential(*modules_neck) + + self.neck_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_up1 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up1 = nn.Sequential(*modules_up1) + + self.up1_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_up2 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up2 = nn.Sequential(*modules_up2) + self.up2_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + + modules_up3 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up3 = nn.Sequential(*modules_up3) + self.up3_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + # define tail module + modules_tail = [ + common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act)] + + self.tail = nn.Sequential(*modules_tail) + + def init_T2_fre_spa_fusion(self, ): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(self.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, ): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_fre_T1 = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(self.num_features, False, 
False), + common.FreBlock9(self.num_features, ) + ] + + self.down1_fre_T1 = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down2_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down2_fre_T1 = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down3_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down3_fre_T1 = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_neck_fre = [common.FreBlock9(self.num_features, ) + ] + self.neck_fre_T1 = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + def init_T1_spa_branch(self, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_T1 = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = nn.Sequential(*modules_down1) + + + self.down1_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = nn.Sequential(*modules_down2) + + self.down2_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down3_T1 = nn.Sequential(*modules_down3) + self.down3_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.neck_T1 = nn.Sequential(*modules_neck) + + self.neck_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + + def init_modality_fre_fusion(self, ): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, ): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + def forward(self, main, aux, t): + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = 
self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo) # 64 + up2_fre_mo = self.up2_fre_mo(up2_fre) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse) + down1_fuse_mo = self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.up3(up2_fuse_mo) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + + res = self.tail(up3_fuse_mo) + + 
return res + main, res_fre + main + + + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/__init__.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/common_freq.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6c78467db1391f6475d069dafda7f295b97ae1 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/common_freq.py @@ -0,0 +1,391 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + + +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
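+            # Each pass of the loop below halves the spatial resolution once:
+            # invPixelShuffle(2) folds every 2x2 spatial block into the channel
+            # dimension (C -> 4C, H x W -> H/2 x W/2) and the conv projects the
+            # 4*n_feats channels back to n_feats before a PReLU. The loop runs
+            # log2(scale) times, so scale=4 stacks two such stages.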
+ for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs): + + out = self.layers(inputs) + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features): + super(ResBlock, self).__init__() + self.layers = nn.Sequential( + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, act='ReLU', padding=1), + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, padding=1) + ) + + def forward(self, inputs): + return F.relu(self.layers(inputs) + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks): + super(ResidualGroup, self).__init__() + modules_body = [ + ResBlock(n_feat) for _ in range(n_resblocks)] + + modules_body.append(ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, 
norm=norm)) + self.body = nn.Sequential(*modules_body) + self.re_scale = Scale(1) + + def forward(self, x): + res = self.body(x) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class 
FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, 
posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/modules.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bad92a6d6709130c4ad1cc49d5094958d44d7e71 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/modules.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +USE_PYTORCH_IN = False + + +###################################################################### +# Superclass of all Modules that take two inputs +###################################################################### +class TwoInputModule(nn.Module): + def forward(self, input1, input2): + raise NotImplementedError + +###################################################################### +# A (sort of) hacky way to create a module that takes two inputs (e.g. x and z) +# and returns one output (say o) defined as follows: +# o = module2.forward(module1.forward(x), z) +# Note that module2 MUST support two inputs as well. +###################################################################### +class MergeModule(TwoInputModule): + def __init__(self, module1, module2): + """ module1 could be any module (e.g. Sequential of several modules) + module2 must accept two inputs + """ + super(MergeModule, self).__init__() + self.module1 = module1 + self.module2 = module2 + + def forward(self, input1, input2): + output1 = self.module1.forward(input1) + output2 = self.module2.forward(output1, input2) + return output2 + +###################################################################### +# A (sort of) hacky way to create a container that takes two inputs (e.g. x and z) +# and applies a sequence of modules (exactly like nn.Sequential) but MergeModule +# is one of its submodules it applies it to both inputs +###################################################################### +class TwoInputSequential(nn.Sequential, TwoInputModule): + def __init__(self, *args): + super(TwoInputSequential, self).__init__(*args) + + def forward(self, input1, input2): + """overloads forward function in parent calss""" + + for module in self._modules.values(): + if isinstance(module, TwoInputModule): + input1 = module.forward(input1, input2) + else: + input1 = module.forward(input1) + return input1 + + +###################################################################### +# A standard instance norm module. +# Since the pytorch instance norm used BatchNorm as a base and thus is +# different from the standard implementation. 
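+# InstanceNorm below normalizes each (sample, channel) slice over its H*W
+# positions and, when affine=True, applies a learned per-channel scale
+# (initialized from N(0, 0.02)) and shift (initialized to 0). Setting
+# USE_PYTORCH_IN = True at the top of this file swaps in torch.nn.InstanceNorm2d
+# via the InstanceNorm2d alias instead.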
+###################################################################### +class InstanceNorm(nn.Module): + def __init__(self, num_features, affine=True, eps=1e-5): + """`num_features` number of feature channels + """ + super(InstanceNorm, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + self.scale = Parameter(torch.Tensor(num_features)) + self.shift = Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.scale.data.normal_(mean=0., std=0.02) + self.shift.data.zero_() + + def forward(self, input): + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + centered_x = x_reshaped - mean + std = torch.rsqrt((centered_x ** 2).mean(2, keepdim=True) + self.eps) + norm_features = (centered_x * std).view(*size) + + # broadcast on the batch dimension, hight and width dimensions + if self.affine: + output = norm_features * self.scale[:,None,None] + self.shift[:,None,None] + else: + output = norm_features + + return output + +InstanceNorm2d = nn.InstanceNorm2d if USE_PYTORCH_IN else InstanceNorm + +###################################################################### +# A module implementing conditional instance norm. +# Takes two inputs: x (input features) and z (latent codes) +###################################################################### +class CondInstanceNorm(TwoInputModule): + def __init__(self, x_dim, z_dim, eps=1e-5): + """`x_dim` dimensionality of x input + `z_dim` dimensionality of z latents + """ + super(CondInstanceNorm, self).__init__() + self.eps = eps + self.shift_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + self.scale_conv = nn.Sequential( + nn.Conv2d(z_dim, x_dim, kernel_size=1, padding=0, bias=True), + nn.ReLU(True) + ) + + def forward(self, input, noise): + + shift = self.shift_conv.forward(noise) + scale = self.scale_conv.forward(noise) + size = input.size() + x_reshaped = input.view(size[0], size[1], size[2]*size[3]) + mean = x_reshaped.mean(2, keepdim=True) + var = x_reshaped.var(2, keepdim=True) + std = torch.rsqrt(var + self.eps) + norm_features = ((x_reshaped - mean) * std).view(*size) + output = norm_features * scale + shift + return output + + +###################################################################### +# A modified resnet block which allows for passing additional noise input +# to be used for conditional instance norm +###################################################################### +class CINResnetBlock(TwoInputModule): + def __init__(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + super(CINResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + for idx, module in enumerate(self.conv_block): + self.add_module(str(idx), module) + + def build_conv_block(self, x_dim, z_dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [ + MergeModule( + nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(x_dim, z_dim) + ), + nn.ReLU(True) + ] + if 
use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(x_dim, x_dim, kernel_size=3, padding=p, bias=use_bias), + InstanceNorm2d(x_dim, affine=True)] + + return TwoInputSequential(*conv_block) + + def forward(self, x, noise): + out = self.conv_block(x, noise) + out = self.relu(x + out) + return out + +###################################################################### +# Define a resnet block +###################################################################### +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + self.relu = nn.ReLU(True) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [nn.ReLU(True)] + + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] + conv_block += [norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = self.conv_block(x) + out = self.relu(x + out) + return out diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/mynet.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/mynet.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3734afef3168a620e49722fb3e1a3e64708c42 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/networks_fsm/mynet.py @@ -0,0 +1,376 @@ +import torch +from torch import nn +from . 
import common_freq as common + + +class TwoBranch(nn.Module): + def __init__(self, num_features, act, base_num_every_group, num_channels): + super(TwoBranch, self).__init__() + + self.num_features = num_features + self.act = act + self.num_channels = num_channels + print("num_channels: ", num_channels) + + num_group = 4 + num_every_group = base_num_every_group + + self.init_T2_frq_branch() + self.init_T2_spa_branch( num_every_group) + self.init_T2_fre_spa_fusion() + + self.init_T1_frq_branch() + self.init_T1_spa_branch( num_every_group) + + self.init_modality_fre_fusion() + self.init_modality_spa_fusion() + + + def init_T2_frq_branch(self, ): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_fre = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + + self.down1_fre = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down2_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down2_fre = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_down3_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down3_fre = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + modules_neck_fre = [common.FreBlock9(self.num_features, ) + ] + self.neck_fre = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo = nn.Sequential(common.FreBlock9(self.num_features )) + + modules_up1_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up1_fre = nn.Sequential(*modules_up1_fre) + self.up1_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_up2_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up2_fre = nn.Sequential(*modules_up2_fre) + self.up2_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_up3_fre = [common.UpSampler(2, self.num_features), + common.FreBlock9(self.num_features, ) + ] + self.up3_fre = nn.Sequential(*modules_up3_fre) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features, )) + + # define tail module + modules_tail_fre = [ + common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act)] + self.tail_fre = nn.Sequential(*modules_tail_fre) + + def init_T2_spa_branch(self, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down1 = nn.Sequential(*modules_down1) + + + self.down1_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down2 = nn.Sequential(*modules_down2) + + self.down2_mo = 
nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down3 = nn.Sequential(*modules_down3) + self.down3_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.neck = nn.Sequential(*modules_neck) + + self.neck_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_up1 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up1 = nn.Sequential(*modules_up1) + + self.up1_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_up2 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up2 = nn.Sequential(*modules_up2) + self.up2_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + + modules_up3 = [common.UpSampler(2, self.num_features), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.up3 = nn.Sequential(*modules_up3) + self.up3_mo = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + # define tail module + modules_tail = [ + common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act)] + + self.tail = nn.Sequential(*modules_tail) + + def init_T2_fre_spa_fusion(self, ): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(self.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_T1_frq_branch(self, ): + ### T2frequency branch + modules_head_fre = [common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_fre_T1 = nn.Sequential(*modules_head_fre) + + modules_down1_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + + self.down1_fre_T1 = nn.Sequential(*modules_down1_fre) + self.down1_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down2_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down2_fre_T1 = nn.Sequential(*modules_down2_fre) + + self.down2_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_down3_fre = [common.DownSample(self.num_features, False, False), + common.FreBlock9(self.num_features, ) + ] + self.down3_fre_T1 = nn.Sequential(*modules_down3_fre) + self.down3_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + modules_neck_fre = [common.FreBlock9(self.num_features, ) + ] + self.neck_fre_T1 = nn.Sequential(*modules_neck_fre) + self.neck_fre_mo_T1 = nn.Sequential(common.FreBlock9(self.num_features, )) + + def init_T1_spa_branch(self, num_every_group): + ### spatial branch + modules_head = [common.ConvBNReLU2D(1, 
out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act)] + self.head_T1 = nn.Sequential(*modules_head) + + modules_down1 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down1_T1 = nn.Sequential(*modules_down1) + + + self.down1_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down2 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down2_T1 = nn.Sequential(*modules_down2) + + self.down2_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_down3 = [common.DownSample(self.num_features, False, False), + common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.down3_T1 = nn.Sequential(*modules_down3) + self.down3_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + modules_neck = [common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None) + ] + self.neck_T1 = nn.Sequential(*modules_neck) + + self.neck_mo_T1 = nn.Sequential(common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None)) + + + def init_modality_fre_fusion(self, ): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self, ): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + def forward(self, main, aux, t): + #### T1 fre encoder + t1_fre = self.head_fre_T1(aux) # 128 + + down1_fre_t1 = self.down1_fre_T1(t1_fre)# 64 + down1_fre_mo_t1 = self.down1_fre_mo_T1(down1_fre_t1) + + down2_fre_t1 = self.down2_fre_T1(down1_fre_mo_t1) # 32 + down2_fre_mo_t1 = self.down2_fre_mo_T1(down2_fre_t1) + + down3_fre_t1 = self.down3_fre_T1(down2_fre_mo_t1) # 16 + down3_fre_mo_t1 = self.down3_fre_mo_T1(down3_fre_t1) + + neck_fre_t1 = self.neck_fre_T1(down3_fre_mo_t1) # 16 + neck_fre_mo_t1 = self.neck_fre_mo_T1(neck_fre_t1) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.head_fre(main) # 128 + x_fre_fuse = self.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.down1_fre(x_fre_fuse)# 64 + down1_fre_mo = self.down1_fre_mo(down1_fre) + down1_fre_mo_fuse = self.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.down2_fre(down1_fre_mo_fuse) # 32 + down2_fre_mo = self.down2_fre_mo(down2_fre) + down2_fre_mo_fuse = self.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.down3_fre(down2_fre_mo_fuse) # 16 + down3_fre_mo = self.down3_fre_mo(down3_fre) + down3_fre_mo_fuse = self.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.neck_fre(down3_fre_mo_fuse) # 16 + neck_fre_mo = self.neck_fre_mo(neck_fre) + neck_fre_mo_fuse = self.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.up1_fre(neck_fre_mo) # 32 + up1_fre_mo = self.up1_fre_mo(up1_fre) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.up2_fre(up1_fre_mo) # 
64 + up2_fre_mo = self.up2_fre_mo(up2_fre) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.up3_fre(up2_fre_mo) # 128 + up3_fre_mo = self.up3_fre_mo(up3_fre) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.tail_fre(up3_fre_mo) + + #### T1 spa encoder + x_t1 = self.head_T1(aux) # 128 + + down1_t1 = self.down1_T1(x_t1) # 64 + down1_mo_t1 = self.down1_mo_T1(down1_t1) + + down2_t1 = self.down2_T1(down1_mo_t1) # 32 + down2_mo_t1 = self.down2_mo_T1(down2_t1) # 32 + + down3_t1 = self.down3_T1(down2_mo_t1) # 16 + down3_mo_t1 = self.down3_mo_T1(down3_t1) # 16 + + neck_t1 = self.neck_T1(down3_mo_t1) # 16 + neck_mo_t1 = self.neck_mo_T1(neck_t1) + + #### T2 spa encoder and fusion + x = self.head(main) # 128 + + x_fuse = self.conv_fuse_spa[0](x_t1, x) + down1 = self.down1(x_fuse) # 64 + down1_fuse = self.conv_fuse[0](down1_fre, down1) + down1_mo = self.down1_mo(down1_fuse) + down1_fuse_mo = self.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.down2(down1_fuse_mo_fuse) # 32 + down2_fuse = self.conv_fuse[2](down2_fre, down2) + down2_mo = self.down2_mo(down2_fuse) # 32 + down2_fuse_mo = self.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.down3(down2_fuse_mo_fuse) # 16 + down3_fuse = self.conv_fuse[4](down3_fre, down3) + down3_mo = self.down3_mo(down3_fuse) # 16 + down3_fuse_mo = self.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.neck(down3_fuse_mo_fuse) # 16 + neck_fuse = self.conv_fuse[6](neck_fre, neck) + neck_mo = self.neck_mo(neck_fuse) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.up1(neck_fuse_mo_fuse) # 32 + up1_fuse = self.conv_fuse[8](up1_fre, up1) + up1_mo = self.up1_mo(up1_fuse) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.up2(up1_fuse_mo) # 64 + up2_fuse = self.conv_fuse[10](up2_fre, up2) + up2_mo = self.up2_mo(up2_fuse) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.up3(up2_fuse_mo) # 128 + + up3_fuse = self.conv_fuse[12](up3_fre, up3) + up3_mo = self.up3_mo(up3_fuse) + + up3_mo = up3_mo + x + up3_fuse_mo = self.conv_fuse[13](up3_fre_mo, up3_mo) + + res = self.tail(up3_fuse_mo) + + return res + main, res_fre + main + +def make_model(): + return TwoBranch() + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/new_twobranch_model.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/new_twobranch_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2c99525dd4649886a0b5da769016c87aa5d6245f --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/new_twobranch_model.py @@ -0,0 +1,515 @@ +import math +import torch +import torch.nn as nn + +from .st_branch_model_spa.utils import AMPLoss, PhaLoss + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + 
stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class TransformerBlock(nn.Module): + def __init__(self, embed_dim, num_heads, feedforward_dim, dropout=0.1): + super(TransformerBlock, self).__init__() + self.transformer = nn.TransformerEncoderLayer( + d_model=embed_dim, + nhead=num_heads, + dim_feedforward=feedforward_dim, + dropout=dropout, + batch_first=True, + ) + + def forward(self, x): + return self.transformer(x) + +class CrossAttention(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttention, self).__init__() + self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True) + + def forward(self, query, key, value): + attn_output, attn_weights = self.attention(query, key, value) + return attn_output, attn_weights + + + +class FreBlock(nn.Module): + def __init__(self, channels, embed_dim = 256): + super(FreBlock, self).__init__() + + num_heads = 8 + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_conv = nn.Sequential( + nn.Conv2d(channels, channels, 5, 1, 2), + nn.LeakyReLU(0.1, inplace=True) + ) + self.pha_conv = nn.Sequential( + nn.Conv2d(channels, channels, 5, 1, 2), + nn.LeakyReLU(0.1, inplace=True) + ) + + + self.amp_fuse = nn.Sequential( + TransformerBlock(embed_dim, num_heads, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + # TransformerBlock(embed_dim, num_heads, embed_dim), + # nn.ReLU() + ) + + self.pha_fuse = nn.Sequential( + TransformerBlock(embed_dim, num_heads, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + # TransformerBlock(embed_dim, num_heads, embed_dim) + ) + + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + # self.transformer = TransformerBlock(embed_dim, num_heads, feedforward_dim) + self.cross_attention = CrossAttention(embed_dim, num_heads) + self.cross_attention_2 = CrossAttention(embed_dim, num_heads) + + + def forward(self, x, k=None): + _, _, H, W = x.shape + # k shape, msF_component_fuse shape torch.Size([24, 1, 128, 128]) torch.Size([24, 256, 16, 9]) + + # rfft2 输出的形状 (半频谱): (rows, cols//2 + 1) + # half_W = W // 2 + 1 + # down-scale + # k = torch.nn.functional.interpolate(k, size=(H, W), mode='bilinear', + # align_corners=False).cuda() + # k = k[...,:half_W] + + + fpre = self.fpre(x) + msF = torch.fft. 
rfft2(fpre + 1e-8, norm='ortho') + msF = torch.fft.fftshift(msF, dim=[2, 3]) + + msF_ori= msF.clone() # * k + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + + + msF_amp = self.amp_conv(msF_amp) + msF_pha = self.pha_conv(msF_pha) + + batch_size, channels, height, width = msF_amp.shape + msF_amp_flatten = msF_amp.view(batch_size, channels, -1).permute(0, 2, 1) # (batch_size, H*W, channels) + msF_pha_flatten = msF_pha.view(batch_size, channels, -1).permute(0, 2, 1) # (batch_size, H*W, channels) + # print("msF_amp_flatten shape", msF_amp_flatten.shape) + + # channels = msF_amp.shape[1] + msF_amp_flatten, _ = self.cross_attention( msF_amp_flatten, msF_pha_flatten, msF_pha_flatten) + msF_pha_flatten, _ = self.cross_attention_2(msF_pha_flatten, msF_amp_flatten, msF_amp_flatten) + + amplitude_features = self.amp_fuse(msF_amp_flatten) # + msF_component + angle_features = self.pha_fuse(msF_pha_flatten) # + msF_component + + # cross attention + amp_fuse = amplitude_features.permute(0, 2, 1).view(batch_size, channels, height, width) + pha_fuse = angle_features.permute(0, 2, 1).view(batch_size, channels, height, width) + + amp_fuse = nn.ReLU()(amp_fuse) + real = amp_fuse * torch.cos(pha_fuse) + 1e-8 + imag = amp_fuse * torch.sin(pha_fuse) + 1e-8 + + out = torch.complex(real, imag) + 1e-8 + out = out + msF_ori # * (1 - k) + + out = torch.fft.ifftshift(out, dim=[2, 3]) + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='ortho')) + out = self.post(out) + + + + out = torch.nan_to_num(out, nan=1e-5, posinf=1, neginf=-1) + + return out + + +class Branch(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + + + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels * 2, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,) + ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + fre = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + fre.append(FreBlock(channels=block_in)) + down = nn.Module() + down.block = block + down.attn = attn + down.fre = fre + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + self.mid_fre = FreBlock(channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + fre = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in 
range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + fre.append(FreBlock(channels=block_in)) + up = nn.Module() + up.block = block + up.attn = attn + up.fre = fre + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + + + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + self.spatial = Branch(ch=ch, out_ch=out_ch, ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + dropout=dropout, resamp_with_conv=resamp_with_conv, + in_channels=in_channels, resolution=resolution) + + self.amploss = AMPLoss() # .to(self.device, non_blocking=True) + self.phaloss = PhaLoss() # .to(self.device, non_blocking=True) + + self.use_front_fre = False + self.use_after_fre = False + print("=== use front fre", self.use_front_fre) # NAN + print("=== use after fre", self.use_after_fre) # use_after_fre_ BUG NAN + + def forward(self, x, aux, k, t): + assert x.shape[2] == x.shape[3] == self.resolution + + # k = k.to(x.device) + + # timestep embedding + temp = None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + + x_in = torch.cat((x, aux), dim=1) + + # spatial downsampling + hs = [self.spatial.conv_in(x_in)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.spatial.down[i_level].block[i_block](hs[-1], temb) + if len(self.spatial.down[i_level].attn) > 0: + if self.use_front_fre: + h = self.spatial.down[i_level].fre[i_block](h, k) + h = self.spatial.down[i_level].attn[i_block](h) + + if self.use_after_fre: + h = self.spatial.down[i_level].fre[i_block](h, k) + h + + hs.append(h) + + if i_level != self.num_resolutions-1: + hs.append(self.spatial.down[i_level].downsample(hs[-1])) + + # spatial middle + h = hs[-1] + h = self.spatial.mid.block_1(h, temb) + h = self.spatial.mid.attn_1(h) + h = self.spatial.mid.block_2(h, temb) + + # if self.use_front_fre or self.use_after_fre: + # h = self.spatial.mid_fre(h, k) # + h # NAN?? 
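+        # The bottleneck FreBlock stays commented out: with use_front_fre and
+        # use_after_fre both False (see __init__), the frequency blocks are
+        # skipped throughout, which the "NAN" remarks above suggest was done to
+        # avoid numerical blow-ups. The decoder below mirrors the encoder: each
+        # up-block pops a skip feature from hs and concatenates it with h along
+        # the channel dimension before its ResnetBlock. With the default
+        # ch_mult=(1, 2, 4, 8), for example, the deepest skip carries 8*ch
+        # channels, so the first up-level ResnetBlock was built with
+        # in_channels = 16*ch.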
+ + # spatial upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.spatial.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.spatial.up[i_level].attn) > 0: + if self.use_front_fre: + h = self.spatial.up[i_level].fre[i_block](h, k) + h = self.spatial.up[i_level].attn[i_block](h) + if self.use_after_fre: + h = self.spatial.up[i_level].fre[i_block](h, k) + h + + # TODO residual + # h += hs.pop() + + if i_level != 0: + h = self.spatial.up[i_level].upsample(h) + + # spatial end + h = self.spatial.norm_out(h) + h = nonlinearity(h) + h = self.spatial.conv_out(h) + + return h diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/st_branch_model/common_freq.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/st_branch_model/common_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..f209b9da0894b884345487dde2ebe9344ca131f3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/st_branch_model/common_freq.py @@ -0,0 +1,456 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange +from torch.fft import * + +def frequency_transform(x_input, pixel_range='-1_1', to_frequency=True): + if to_frequency: + if pixel_range == '0_1': + pass + + elif pixel_range == '-1_1': + # x_start (-1, 1) --> (0, 1) + x_start = (x_input + 1) / 2 + + elif pixel_range == 'complex': + x_start = torch.complex(x_input[:, :1, ...], x_input[:, 1:, ...]) + + else: + raise ValueError(f"Unknown pixel range {pixel_range}.") + + fft = fftshift(fft2(x_input)) + return fft + + else: + x_ksu = ifft2(ifftshift(x_input)) + + if pixel_range == '0_1': + x_ksu = torch.abs(x_ksu) + + elif pixel_range == '-1_1': + x_ksu = torch.abs(x_ksu) + # x_ksu (0, 1) --> (-1, 1) + x_ksu = x_ksu * 2 - 1 + + elif pixel_range == 'complex': + x_ksu = torch.concat((x_ksu.real, x_ksu.imag), dim=1) + else: + raise ValueError(f"Unknown pixel range {pixel_range}.") + + return x_ksu + + +class Scale(nn.Module): + + def __init__(self, init_value=1e-3): + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class invPixelShuffle(nn.Module): + def __init__(self, ratio=2): + super(invPixelShuffle, self).__init__() + self.ratio = ratio + + def forward(self, tensor): + ratio = self.ratio + b = tensor.size(0) + ch = tensor.size(1) + y = tensor.size(2) + x = tensor.size(3) + assert x % ratio == 0 and y % ratio == 0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) + tensor = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio).permute(0, 1, 3, 5, 2, 4) + return tensor.contiguous().view(b, -1, y // ratio, x // ratio) + + +class UpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(in_channels=n_feats, out_channels=4 * n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PixelShuffle(upscale_factor=2)) + m.append(nn.PReLU()) + super(UpSampler, self).__init__(*m) + + +class InvUpSampler(nn.Sequential): + def __init__(self, scale, n_feats): + + m = [] + if scale == 8: + kernel_size = 3 + elif scale == 16: + kernel_size = 5 + else: + kernel_size = 1 + if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
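            # Each power-of-two step below applies invPixelShuffle(2), a
            # space-to-depth rearrangement (B, C, H, W) -> (B, 4C, H/2, W/2),
            # followed by a 4*n_feats -> n_feats conv and PReLU, i.e. the
            # inverse of the PixelShuffle-based UpSampler defined above.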
+ for _ in range(int(math.log(scale, 2))): + m.append(invPixelShuffle(2)) + m.append(nn.Conv2d(in_channels=n_feats * 4, out_channels=n_feats, kernel_size=kernel_size, stride=1, + padding=kernel_size // 2)) + m.append(nn.PReLU()) + super(InvUpSampler, self).__init__(*m) + + + +class AdaptiveNorm(nn.Module): + def __init__(self, n): + super(AdaptiveNorm, self).__init__() + + self.w_0 = nn.Parameter(torch.Tensor([1.0])) + self.w_1 = nn.Parameter(torch.Tensor([0.0])) + + self.bn = nn.BatchNorm2d(n, momentum=0.999, eps=0.001) + + def forward(self, x): + return self.w_0 * x + self.w_1 * self.bn(x) + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + +class ConvBNReLU2D(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False, act=None, norm=None, temb_ch=None): + super(ConvBNReLU2D, self).__init__() + + self.layers = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + + self.temb_proj = None + if temb_ch != None: + self.temb_proj = torch.nn.Linear(temb_ch, + out_channels) + + + self.act = None + self.norm = None + if norm == 'BN': + self.norm = torch.nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = torch.nn.InstanceNorm2d(out_channels) + elif norm == 'GN': + self.norm = torch.nn.GroupNorm(2, out_channels) + elif norm == 'WN': + self.layers = torch.nn.utils.weight_norm(self.layers) + elif norm == 'Adaptive': + self.norm = AdaptiveNorm(n=out_channels) + + if act == 'PReLU': + self.act = torch.nn.PReLU() + elif act == 'SELU': + self.act = torch.nn.SELU(True) + elif act == 'LeakyReLU': + self.act = torch.nn.LeakyReLU(negative_slope=0.02, inplace=True) + elif act == 'ELU': + self.act = torch.nn.ELU(inplace=True) + elif act == 'ReLU': + self.act = torch.nn.ReLU(True) + elif act == 'Tanh': + self.act = torch.nn.Tanh() + elif act == 'Sigmoid': + self.act = torch.nn.Sigmoid() + elif act == 'SoftMax': + self.act = torch.nn.Softmax2d() + + def forward(self, inputs, temp=None): + + out = self.layers(inputs) + + if self.temb_proj != None and temp !=None: + out = out + self.temb_proj(nonlinearity(temp))[:, :, None, None] + + if self.norm is not None: + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, num_features, temb_ch): + super(ResBlock, self).__init__() + self.layers_1 = \ + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, + act='ReLU', padding=1, temb_ch=temb_ch) + self.layers_2 = \ + ConvBNReLU2D(num_features, out_channels=num_features, kernel_size=3, + padding=1, temb_ch=temb_ch) + + + + def forward(self, inputs, temp=None): + out = self.layers_1(inputs, temp) + out = self.layers_2(out, temp) + + return F.relu(out + inputs) + + +class DownSample(nn.Module): + def __init__(self, num_features, act, norm, scale=2): + super(DownSample, self).__init__() + if scale == 1: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + ConvBNReLU2D(in_channels=num_features * scale * scale, out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + else: + self.layers = nn.Sequential( + ConvBNReLU2D(in_channels=num_features, out_channels=num_features, kernel_size=3, act=act, norm=norm, padding=1), + invPixelShuffle(ratio=scale), + ConvBNReLU2D(in_channels=num_features * scale * scale, 
out_channels=num_features, kernel_size=1, act=act, norm=norm) + ) + + def forward(self, inputs): + return self.layers(inputs) + + +class ResidualGroup(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act,norm, n_resblocks, temb_ch=None): + super(ResidualGroup, self).__init__() + self.body = nn.ModuleList([ + ResBlock(n_feat, temb_ch) for _ in range(n_resblocks)]) + + + self.end = ConvBNReLU2D(n_feat, n_feat, kernel_size, padding=1, act=act, + norm=norm, temb_ch=temb_ch) + + + self.re_scale = Scale(1) + + def forward(self, x, temp): + # res = self.body(x) + res = x + for block in self.body: + res = block(res, temp) + res = self.end(res, temp) + return res + self.re_scale(x) + + + +class FreBlock9(nn.Module): + def __init__(self, channels): + super(FreBlock9, self).__init__() + + self.fpre = nn.Conv2d(channels, channels, 1, 1, 0) + self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 3, 1, 1), nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels, channels, 3, 1, 1)) + self.post = nn.Conv2d(channels, channels, 1, 1, 0) + + + def forward(self, x, temp=None): + # print("x: ", x.shape) + _, _, H, W = x.shape + msF = torch.fft.rfft2(self.fpre(x)+1e-8, norm='backward') + + msF_amp = torch.abs(msF) + msF_pha = torch.angle(msF) + # print("msf_amp: ", msF_amp.shape) + amp_fuse = self.amp_fuse(msF_amp) + # print(amp_fuse.shape, msF_amp.shape) + amp_fuse = amp_fuse + msF_amp + pha_fuse = self.pha_fuse(msF_pha) + pha_fuse = pha_fuse + msF_pha + + real = amp_fuse * torch.cos(pha_fuse)+1e-8 + imag = amp_fuse * torch.sin(pha_fuse)+1e-8 + out = torch.complex(real, imag)+1e-8 + out = torch.abs(torch.fft.irfft2(out, s=(H, W), norm='backward')) + out = self.post(out) + out = out + x + out = torch.nan_to_num(out, nan=1e-5, posinf=1e-5, neginf=1e-5) + # print("out: ", out.shape) + return out + +class Attention(nn.Module): + def __init__(self, dim=64, num_heads=8, bias=False): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) + self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) + self.q = nn.Conv2d(dim, dim , kernel_size=1, bias=bias) + self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x, y): + b, c, h, w = x.shape + + kv = self.kv_dwconv(self.kv(y)) + k, v = kv.chunk(2, dim=1) + q = self.q_dwconv(self.q(x)) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + +class FuseBlock7(nn.Module): + def __init__(self, channels): + super(FuseBlock7, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fre_att = Attention(dim=channels) + 
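        # FuseBlock7 fuses frequency- and spatial-branch features with
        # bidirectional channel cross-attention and a sigmoid gate over the
        # concatenated result. Note that forward() below calls self.fre_att for
        # both directions, so self.spa_att is constructed but never used;
        # presumably the second call was meant to be
        #   spa = self.spa_att(spa, fre) + spa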
self.spa_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + ori = spa + fre = self.fre(fre) + spa = self.spa(spa) + fre = self.fre_att(fre, spa)+fre + spa = self.fre_att(spa, fre)+spa + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class FuseBlock6(nn.Module): + def __init__(self, channels): + super(FuseBlock6, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + fre_a, spa_a = fuse.chunk(2, dim=1) + spa = spa_a * spa + fre = fre * fre_a + res = fre + spa + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.fre = nn.Conv2d(channels, channels, 3, 1, 1) + self.spa = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, spa, fre): + fre = self.fre(fre) + spa = self.spa(spa) + + fuse = self.fuse(torch.cat((fre, spa), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock7(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock7, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t1_att = Attention(dim=channels) + self.t2_att = Attention(dim=channels) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + t1 = self.t1_att(t1, t2)+t1 + t2 = self.t2_att(t2, t1)+t2 + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + +class Modality_FuseBlock6(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock6, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, 2*channels, 3, 1, 1), nn.Sigmoid()) + + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + t1_a, t2_a = fuse.chunk(2, dim=1) + t2 = t2_a * t2 + t1 = t1 * t1_a + res = t1 + t2 + + res = 
torch.nan_to_num(res, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + + +class Modality_FuseBlock4(nn.Module): + def __init__(self, channels): + super(Modality_FuseBlock4, self).__init__() + self.t1 = nn.Conv2d(channels, channels, 3, 1, 1) + self.t2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.fuse = nn.Sequential(nn.Conv2d(2*channels, channels, 3, 1, 1), nn.Conv2d(channels, channels, 3, 1, 1)) + + def forward(self, t1, t2): + t1 = self.t1(t1) + t2 = self.t2(t2) + + fuse = self.fuse(torch.cat((t1, t2), 1)) + res = torch.nan_to_num(fuse, nan=1e-5, posinf=1e-5, neginf=1e-5) + return res + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/st_branch_model/model.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/st_branch_model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1ade669aa7079d8bb9a8c65e6190ae92a1bfd639 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/st_branch_model/model.py @@ -0,0 +1,786 @@ +import torch, math +from torch import nn +from . import common_freq as common +import torch.nn.functional as F + +from .utils import adopt_weight, hinge_d_loss, vanilla_d_loss +from metrics.lpips import LPIPS +# from vq_gan_3d.model.codebook import Codebook +import numpy as np + + +def silu(x): + return x * torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +class DownBlock(nn.Module): + def __init__(self, num_features, act=True, norm=True, fre_layer=False, + kernel_size=3, reduction = 4, num_every_group=1, temb_ch=None, + spa_norm=None, spa_act=None): + super(DownBlock, self).__init__() + + self.downsample = common.DownSample(num_features, act, norm) + + self.fre_layer = None + self.spa_layer = None + if fre_layer: + self.fre_layer = common.FreBlock9(num_features) + else: + self.spa_layer = common.ResidualGroup( + num_features, kernel_size, reduction, act=spa_act, + n_resblocks=num_every_group, norm=spa_norm, temb_ch=temb_ch) + + def forward(self, x, temp=None): + out = self.downsample(x) + if self.fre_layer is not None: + out = self.fre_layer(out, temp) + else: + out = self.spa_layer(out, temp) + return out + + + +class UpBlock(nn.Module): + def __init__(self, scale, num_features, act=True, norm=True, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=1, temb_ch=None, + spa_norm=None, spa_act=None + ): + super(UpBlock, self).__init__() + + self.upsample = common.UpSampler(scale, num_features) + self.fre_layer = None + self.spa_layer = None + if fre_layer: + self.fre_layer = common.FreBlock9(num_features) + else: + self.spa_layer = common.ResidualGroup( + num_features, kernel_size, reduction, act=spa_act, + n_resblocks=num_every_group, norm=spa_norm, temb_ch=temb_ch) + + + def forward(self, x, temp=None): + out = self.upsample(x) + if self.fre_layer is not None: + out = self.fre_layer(out, temp) + else: + out = self.spa_layer(out, temp) + return out + + + +class DuplicateBlock(nn.Module): + def __init__(self, block, num_of_block, **kwargs): + super(DuplicateBlock, self).__init__() + + self.blocks = nn.ModuleList([block(**kwargs) for _ in range(num_of_block)]) + + def forward(self, x, temp=None): + for block in self.blocks: + x = block(x, temp) + return x + + +class ModelBackbone(nn.Module): + def __init__(self, num_features, act, base_num_every_group, num_channels, temb_ch): + super(ModelBackbone, self).__init__() + + self.num_features = num_features + self.act = act + self.num_channels = num_channels + 
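        # ModelBackbone wires up four encoder paths (T2 frequency, T2 spatial,
        # T1 frequency, T1 spatial) plus the cross-branch fusion modules via the
        # init_* helpers called below. Note that init_T2_frq_branch() assigns
        # the head/down/neck/up/tail frequency modules twice back to back; the
        # second set of assignments simply overwrites the first, so the
        # duplicated block is redundant rather than harmful.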
self.temb_ch = temb_ch + + # self.args = args + num_every_group = base_num_every_group + + self.init_T2_frq_branch() + self.init_T2_spa_branch(num_every_group) + self.init_T2_fre_spa_fusion() + + self.init_T1_frq_branch() + self.init_T1_spa_branch(num_every_group) + + self.init_modality_fre_fusion() + self.init_modality_spa_fusion() + + def init_T2_frq_branch(self): + ### T2frequency branch + self.head_fre = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + self.down1_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down1_fre_mo = common.FreBlock9(self.num_features) + + self.down2_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down2_fre_mo = common.FreBlock9(self.num_features) + + self.down3_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down3_fre_mo = common.FreBlock9(self.num_features) + + self.neck_fre = common.FreBlock9(self.num_features) + + self.neck_fre_mo = common.FreBlock9(self.num_features) + + ### T2frequency branch + self.head_fre = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + self.down1_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down1_fre_mo = common.FreBlock9(self.num_features) + + self.down2_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down2_fre_mo = common.FreBlock9(self.num_features) + + self.down3_fre = DownBlock(self.num_features, False, False, fre_layer=True) + + self.down3_fre_mo = common.FreBlock9(self.num_features) + + self.neck_fre = common.FreBlock9(self.num_features) + + self.neck_fre_mo = common.FreBlock9(self.num_features) + + self.up1_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up1_fre_mo = common.FreBlock9(self.num_features) + + + self.up2_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up2_fre_mo = common.FreBlock9(self.num_features) + + + self.up3_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up3_fre_mo = nn.Sequential(common.FreBlock9(self.num_features)) + + # define tail module + self.tail_fre = common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act, temb_ch=self.temb_ch) + + + + self.up1_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up1_fre_mo = common.FreBlock9(self.num_features) + + self.up2_fre = UpBlock(2, self.num_features, fre_layer=True) + + self.up2_fre_mo = common.FreBlock9(self.num_features) + + self.up3_fre = UpBlock(2, self.num_features, fre_layer=True) + self.up3_fre_mo = common.FreBlock9(self.num_features) + + # define tail module + self.tail_fre = common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act, temb_ch=self.temb_ch) + + + + def init_T2_spa_branch(self, num_every_group): + ### spatial branch + + self.head = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + + self.down1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, + num_every_group=num_every_group, temb_ch=None, spa_norm=None, spa_act=self.act) + + + self.down1_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.down2 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, + 
num_every_group=num_every_group, temb_ch=None, spa_norm=None, spa_act=self.act) + + self.down2_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.down3 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, + num_every_group=num_every_group, temb_ch=None, spa_norm=None, spa_act=self.act) + + self.down3_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.neck = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.neck_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.up1 = UpBlock(2, self.num_features, act=None, norm=None, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + + self.up1_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.up2 = UpBlock(2, self.num_features, act=None, norm=None, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + self.up2_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + + self.up3 = UpBlock(2, self.num_features, act=None, norm=None, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch) + + self.up3_mo = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + # define tail module + self.tail = common.ConvBNReLU2D(self.num_features, out_channels=self.num_channels, kernel_size=3, padding=1, + act=self.act, temb_ch=self.temb_ch) + + + + + + def init_T1_frq_branch(self): + ### T2frequency branch + self.head_fre_T1 = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + + + self.down1_fre_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=True, temb_ch=self.temb_ch) + + self.down1_fre_mo_T1 = common.FreBlock9(self.num_features) + + + self.down2_fre_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=True, temb_ch=self.temb_ch) + + self.down2_fre_mo_T1 = common.FreBlock9(self.num_features) + + + self.down3_fre_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=True, temb_ch=self.temb_ch) + self.down3_fre_mo_T1 = common.FreBlock9(self.num_features) + + + self.neck_fre_T1 = common.FreBlock9(self.num_features) + self.neck_fre_mo_T1 = common.FreBlock9(self.num_features) + + + + + def init_T1_spa_branch(self, num_every_group): + ### spatial branch + + self.head_T1 = common.ConvBNReLU2D(1, out_channels=self.num_features, + kernel_size=3, padding=1, act=self.act, temb_ch=self.temb_ch) + + + self.down1_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + + self.down1_mo_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, + n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + + self.down2_T1 = DownBlock(self.num_features, act=False, 
norm=False, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + + self.down2_mo_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, + norm=None, temb_ch=self.temb_ch) + + + self.down3_T1 = DownBlock(self.num_features, act=False, norm=False, fre_layer=False, + kernel_size=3, reduction=4, num_every_group=num_every_group, temb_ch=self.temb_ch, + spa_norm=None, spa_act=self.act) + + + self.down3_mo_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, + norm=None, temb_ch=self.temb_ch) + + + self.neck_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, norm=None, temb_ch=self.temb_ch) + + + self.neck_mo_T1 = common.ResidualGroup( + self.num_features, 3, 4, act=self.act, n_resblocks=num_every_group, + norm=None, temb_ch=self.temb_ch) + + + def init_T2_fre_spa_fusion(self): + ### T2 frq & spa fusion part + conv_fuse = [] + for i in range(14): + conv_fuse.append(common.FuseBlock7(self.num_features)) + self.conv_fuse = nn.Sequential(*conv_fuse) + + def init_modality_fre_fusion(self): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_fre = nn.Sequential(*conv_fuse) + + def init_modality_spa_fusion(self): + conv_fuse = [] + for i in range(5): + conv_fuse.append(common.Modality_FuseBlock6(self.num_features)) + self.conv_fuse_spa = nn.Sequential(*conv_fuse) + + + # def init_T2_fre_spa_fusion(self): + # ### T2 frq & spa fusion part + # self.conv_fuse = DuplicateBlock(common.FuseBlock7, 14, + # channels=self.num_features) + # + # def init_modality_fre_fusion(self): + # self.conv_fuse_fre = DuplicateBlock(common.FuseBlock6, 5, channels=self.num_features) + # + # def init_modality_spa_fusion(self): + # self.conv_fuse_spa = DuplicateBlock(common.FuseBlock6, 5, channels=self.num_features) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
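    Concretely, with half = embedding_dim // 2 and freqs[j] = 10000 ** (-j / (half - 1)),
    the returned embedding is concat(sin(t * freqs), cos(t * freqs)) along dim 1,
    zero-padded by one column when embedding_dim is odd.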
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + + +class TwoBranchModel(nn.Module): + def __init__(self, num_features, act, base_num_every_group, num_channels): + super(TwoBranchModel, self).__init__() + + num_group = 4 + self.use_fre_mix = False + self.ch = num_channels + self.temb_ch = num_channels * 4 + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(num_channels, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ]) + + + self.model = ModelBackbone(num_features, act, + base_num_every_group, num_channels, + temb_ch=self.temb_ch) + + + + + def forward(self, main, aux, t): + + # timestep embedding + temb =None + + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + # + + #### T1 fre encoder # T1 + temb_fre = None + t1_fre = self.model.head_fre_T1(aux, temb_fre) # 128 + + down1_fre_t1 = self.model.down1_fre_T1(t1_fre, temb_fre)# 64 + down1_fre_mo_t1 = self.model.down1_fre_mo_T1(down1_fre_t1, temb_fre) + + down2_fre_t1 = self.model.down2_fre_T1(down1_fre_mo_t1, temb_fre) # 32 + down2_fre_mo_t1 = self.model.down2_fre_mo_T1(down2_fre_t1, temb_fre) + + down3_fre_t1 = self.model.down3_fre_T1(down2_fre_mo_t1, temb_fre) # 16 + down3_fre_mo_t1 = self.model.down3_fre_mo_T1(down3_fre_t1, temb_fre) + + neck_fre_t1 = self.model.neck_fre_T1(down3_fre_mo_t1, temb_fre) # 16 + neck_fre_mo_t1 = self.model.neck_fre_mo_T1(neck_fre_t1, temb_fre) + + + #### T2 fre encoder and T1 & T2 fre fusion + x_fre = self.model.head_fre(main, temb) # 128 + x_fre_fuse = self.model.conv_fuse_fre[0](t1_fre, x_fre) + + down1_fre = self.model.down1_fre(x_fre_fuse, temb)# 64 + down1_fre_mo = self.model.down1_fre_mo(down1_fre, temb) + down1_fre_mo_fuse = self.model.conv_fuse_fre[1](down1_fre_mo_t1, down1_fre_mo) + + down2_fre = self.model.down2_fre(down1_fre_mo_fuse, temb) # 32 + down2_fre_mo = self.model.down2_fre_mo(down2_fre, temb) + down2_fre_mo_fuse = self.model.conv_fuse_fre[2](down2_fre_mo_t1, down2_fre_mo) + + down3_fre = self.model.down3_fre(down2_fre_mo_fuse, temb) # 16 + down3_fre_mo = self.model.down3_fre_mo(down3_fre, temb) + down3_fre_mo_fuse = self.model.conv_fuse_fre[3](down3_fre_mo_t1, down3_fre_mo) + + neck_fre = self.model.neck_fre(down3_fre_mo_fuse, temb) # 16 + neck_fre_mo = self.model.neck_fre_mo(neck_fre, temb) + neck_fre_mo_fuse = self.model.conv_fuse_fre[4](neck_fre_mo_t1, neck_fre_mo) + + + #### T2 fre decoder + neck_fre_mo = neck_fre_mo_fuse + down3_fre_mo_fuse + + up1_fre = self.model.up1_fre(neck_fre_mo, temb) # 32 + up1_fre_mo = self.model.up1_fre_mo(up1_fre, temb) + up1_fre_mo = up1_fre_mo + down2_fre_mo_fuse + + up2_fre = self.model.up2_fre(up1_fre_mo, temb) # 64 + up2_fre_mo = self.model.up2_fre_mo(up2_fre, temb) + up2_fre_mo = up2_fre_mo + down1_fre_mo_fuse + + up3_fre = self.model.up3_fre(up2_fre_mo, temb) # 128 + up3_fre_mo = self.model.up3_fre_mo(up3_fre, temb) + up3_fre_mo = up3_fre_mo + x_fre_fuse + + res_fre = self.model.tail_fre(up3_fre_mo, temb) + + #### T1 spa encoder + x_t1 = 
self.model.head_T1(aux, temb) # 128 + + down1_t1 = self.model.down1_T1(x_t1, temb) # 64 + down1_mo_t1 = self.model.down1_mo_T1(down1_t1, temb) + + down2_t1 = self.model.down2_T1(down1_mo_t1, temb) # 32 + down2_mo_t1 = self.model.down2_mo_T1(down2_t1, temb) # 32 + + down3_t1 = self.model.down3_T1(down2_mo_t1, temb) # 16 + down3_mo_t1 = self.model.down3_mo_T1(down3_t1, temb) # 16 + + neck_t1 = self.model.neck_T1(down3_mo_t1, temb) # 16 + neck_mo_t1 = self.model.neck_mo_T1(neck_t1, temb) + + #### T2 spa encoder and fusion + x = self.model.head(main, temb) # 128 + + x_fuse = self.model.conv_fuse_spa[0](x_t1, x) + down1 = self.model.down1(x_fuse, temb) # 64 + down1_fuse = self.model.conv_fuse[0](down1_fre, down1) + down1_mo = self.model.down1_mo(down1_fuse, temb) + down1_fuse_mo = self.model.conv_fuse[1](down1_fre_mo_fuse, down1_mo) + + down1_fuse_mo_fuse = self.model.conv_fuse_spa[1](down1_mo_t1, down1_fuse_mo) + down2 = self.model.down2(down1_fuse_mo_fuse, temb) # 32 + down2_fuse = self.model.conv_fuse[2](down2_fre, down2) + down2_mo = self.model.down2_mo(down2_fuse, temb) # 32 + down2_fuse_mo = self.model.conv_fuse[3](down2_fre_mo, down2_mo) + + down2_fuse_mo_fuse = self.model.conv_fuse_spa[2](down2_mo_t1, down2_fuse_mo) + down3 = self.model.down3(down2_fuse_mo_fuse, temb) # 16 + down3_fuse = self.model.conv_fuse[4](down3_fre, down3) + down3_mo = self.model.down3_mo(down3_fuse, temb) # 16 + down3_fuse_mo = self.model.conv_fuse[5](down3_fre_mo, down3_mo) + + down3_fuse_mo_fuse = self.model.conv_fuse_spa[3](down3_mo_t1, down3_fuse_mo) + neck = self.model.neck(down3_fuse_mo_fuse, temb) # 16 + neck_fuse = self.model.conv_fuse[6](neck_fre, neck) + neck_mo = self.model.neck_mo(neck_fuse, temb) + neck_mo = neck_mo + down3_mo + neck_fuse_mo = self.model.conv_fuse[7](neck_fre_mo, neck_mo) + + neck_fuse_mo_fuse = self.model.conv_fuse_spa[4](neck_mo_t1, neck_fuse_mo) + #### T2 spa decoder + up1 = self.model.up1(neck_fuse_mo_fuse, temb) # 32 + up1_fuse = self.model.conv_fuse[8](up1_fre, up1) + up1_mo = self.model.up1_mo(up1_fuse, temb) + up1_mo = up1_mo + down2_mo + up1_fuse_mo = self.model.conv_fuse[9](up1_fre_mo, up1_mo) + + up2 = self.model.up2(up1_fuse_mo, temb) # 64 + up2_fuse = self.model.conv_fuse[10](up2_fre, up2) + up2_mo = self.model.up2_mo(up2_fuse, temb) + up2_mo = up2_mo + down1_mo + up2_fuse_mo = self.model.conv_fuse[11](up2_fre_mo, up2_mo) + + up3 = self.model.up3(up2_fuse_mo, temb) # 128 + + up3_fuse = self.model.conv_fuse[12](up3_fre, up3) + up3_mo = self.model.up3_mo(up3_fuse, temb) + + up3_mo = up3_mo + x + up3_fuse_mo = self.model.conv_fuse[13](up3_fre_mo, up3_mo) + + res = self.model.tail(up3_fuse_mo, temb) + + # if self.use_res: + # res = res + # res_fre = res_fre + + return res + main, res_fre + main + + + def training_step(self, batch, batch_idx, optimizer_idx): + self.model.train() + + x = batch['image'] + aux = batch['aux'] + + x = x.squeeze(1) + aux = aux.squeeze(1) + + # print(x.shape) + # print(aux.shape) + + # torch.Size([8, 96, 96, 1]) + # torch.Size([16, 1, 96, 96, 96]) + + x = x.permute(0, -1, -3, -2)#.detach() # [B, C, H, W] + aux = aux.permute(0, -1, -3, -2)#.detach() # [B, C, H, W] + + out = self.forward(x, aux) + recon_out = out['recon_out'] + recon_fre = out['recon_fre'] + + if optimizer_idx == 0: + fft_weight = 0.01 + use_dis = False + recon_out_loss = self.get_recon_loss(recon_out, x, tag="recon_out", use_dis=use_dis) + recon_fre_loss = self.get_recon_loss(recon_fre, x, tag="recon_fre", use_dis=use_dis) + # amp = self.amploss(recon_fre, x) + # pha = 
self.phaloss(recon_fre, x) + loss = recon_out_loss + recon_fre_loss #+ fft_weight * ( amp + pha ) + + elif optimizer_idx == 1: + loss = self.get_dis_loss(recon_out, x, tag="dis") + + # print("loss = ", loss) + + return loss + + + def get_dis_loss(self, recon, target, tag="dis"): + B, C, H, W = recon.shape + # Selects one random 2D image from each 3D Image + + logits_image_real, _ = self.image_discriminator(target.detach()) + logits_image_fake, _ = self.image_discriminator(recon.detach()) + + print("logits_image_real = ", torch.mean(logits_image_real)) + print("logits_image_fake = ", torch.mean(logits_image_fake)) + + d_image_loss = self.disc_loss(logits_image_real , logits_image_fake) + # print("d_image_loss = ", d_image_loss) + + # d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.discriminator_iter_start) + discloss = disc_factor * (self.image_gan_weight * d_image_loss ) + + self.log(f"train/{tag}/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/disc_loss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + + def get_recon_loss(self, recon, target, tag="recon_out", use_dis=True): + recon_loss = F.l1_loss(recon, target) * self.l1_weight + # recon_loss = ((recon - target)**2).mean() * self.l1_weight + + # Perceptual loss + perceptual_loss = 0 + aeloss = 0 + image_gan_feat_loss = 0 + g_image_loss = 0 + + # Slice it into T, H, W random slices + if self.perceptual_weight > 0: + B, C, H, W = recon.shape + # Selects one random 2D image from each 3D Image + + perceptual_loss = self.perceptual_model(recon, target).mean() * self.perceptual_weight + recon_loss += perceptual_loss + + + if use_dis: + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator(recon) + # logits_video_fake, pred_video_fake = self.video_discriminator(recon) + g_image_loss = -torch.mean(logits_image_fake) + # g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight * g_image_loss + + disc_factor = adopt_weight( + self.global_step, threshold=self.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( recon) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + + gan_feat_loss = disc_factor * self.gan_feat_weight * (image_gan_feat_loss) + recon_loss += gan_feat_loss + aeloss + + self.log(f"train/{tag}/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log(f"train/{tag}/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log(f"train/{tag}/recon_loss", 
recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + return recon_loss + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw - 1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/networks/st_branch_model/utils.py b/MRI_recon/new_code/Frequency-Diffusion-main/networks/st_branch_model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c92a018e6b6409617ff7982339736b9db36c7fa --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/networks/st_branch_model/utils.py @@ -0,0 +1,220 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np +import skvideo.io + +import sys +import pdb as pdb_original +import SimpleITK as sitk +import logging +from torch import nn +import torch.nn.functional as F + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src_tf dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + 
if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + + +class PhaLoss(nn.Module): + def __init__(self, epsilon=1e-8, norm='ortho'): + super(PhaLoss, self).__init__() + self.cri = nn.L1Loss() + self.epsilon = epsilon # To prevent undefined phase for zero magnitudes + self.norm = norm # Normalization for FFT + + def forward(self, x, y): + # Validate inputs + if not torch.isfinite(x).all() or not torch.isfinite(y).all(): + raise ValueError("Input contains NaN or Inf values") + + # Perform FFT + x_fft = torch.fft.rfft2(x, norm=self.norm) + y_fft = torch.fft.rfft2(y, norm=self.norm) + + # Compute phase + x_phase = torch.angle(x_fft) + y_phase = torch.angle(y_fft) + + # Compute L1 loss between phases + return self.cri(x_phase, y_phase) diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/reproduce.sh b/MRI_recon/new_code/Frequency-Diffusion-main/reproduce.sh new file mode 100644 index 0000000000000000000000000000000000000000..88a257dcf5559b4ba7a5d59ad6111ae25aa20b18 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/reproduce.sh @@ -0,0 +1,34 @@ +mamba create -n diffmri python=3.10 -y + +mamba activate diffmri + + +pip install -r ./requirements.txt + + +# pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 \ + # --extra-index-url https://download.pytorch.org/whl/cu113 + + + +pip install torch==2.4.0+cu124 torchvision==0.19.0+cu124 torchaudio==2.4.0+cu124 \ + --extra-index-url https://download.pytorch.org/whl/cu124 + + + +pip install comet_ml torchgeometry albumentations +pip install --upgrade matplotlib +pip install --upgrade scikit-learn pytorch_msssim + + +pip install --upgrade pandas scipy scikit-image scikit-video scipy pytorch_lightning einops SimpleITK +pip install h5py fastmri torchmetrics +# torchmetric +pip install wandb==0.19 +pip install tensorboardX timm ml_collections + +# DR: Retinal Fundus Imaging +# OCT + Retinal Fundus Imaging + + + diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/requirements.txt b/MRI_recon/new_code/Frequency-Diffusion-main/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..955346e0e023adaec83beaa0ac38ab8cbb7be7ac --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/requirements.txt @@ -0,0 +1,104 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 +async-timeout==4.0.2 +attrs==21.4.0 
+autopep8==1.6.0 +cachetools==5.2.0 +certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.9.2 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 +rotary-embedding-torch==0.1.5 +rsa==4.8 +scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 +smmap==5.0.0 +tensorboard +tensorboard-data-server +tensorboard-plugin-wit +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 +zipp==3.8.0 +tensorboardX==2.4.1 +protobuf==3.20.1 +sk-video +torchstat +timm +elasticdeform +opencv-python +monai[nibabel] +monai==0.9.0 +nibabel +ml-collections +glob2 diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/test.py b/MRI_recon/new_code/Frequency-Diffusion-main/test.py new file mode 100644 index 0000000000000000000000000000000000000000..96bda5d729e48687f7a1c4ffe111363b3af3da9e --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/test.py @@ -0,0 +1,182 @@ +from diffusion_pytorch import GaussianDiffusion, Trainer, Model +from Fid import calculate_fid_given_samples +import torchvision +import os +import errno +import shutil +import argparse + + +def create_folder(path): + try: + os.mkdir(path) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + pass + + +def del_folder(path): + try: + shutil.rmtree(path) + except OSError as exc: + pass + + +create = 0 + +if create: + trainset = torchvision.datasets.CIFAR10( + root='./data', train=False, download=True) + root = './root_cifar10_test/' + del_folder(root) + create_folder(root) + + for i in range(10): + lable_root = root + str(i) + '/' + create_folder(lable_root) + + for idx in range(len(trainset)): + img, label = trainset[idx] + print(idx) + img.save(root + str(label) + '/' + str(idx) + '.png') + + +parser = argparse.ArgumentParser() +parser.add_argument('--time_steps', default=50, type=int) +parser.add_argument('--sample_steps', default=None, type=int) +parser.add_argument('--kernel_std', default=0.1, type=float) +parser.add_argument('--save_folder', default='progression_cifar', type=str) +parser.add_argument('--load_path', default='/cmlscratch/eborgnia/cold_diffusion/paper_defading_random_1/model.pt', type=str) +parser.add_argument('--data_path', default='./root_cifar10_test/', type=str) +parser.add_argument('--test_type', default='test_paper_showing_diffusion_images_diff', type=str) +parser.add_argument('--fade_routine', default='Random_Incremental', type=str) 
+parser.add_argument('--sampling_routine', default='x0_step_down', type=str) +parser.add_argument('--remove_time_embed', action="store_true") +parser.add_argument('--discrete', action="store_true") +parser.add_argument('--residual', action="store_true") + +args = parser.parse_args() +print(args) + +img_path=None +if 'train' in args.test_type: + img_path = args.data_path +elif 'test' in args.test_type: + img_path = args.data_path + +print("Img Path is ", img_path) + + + +image_channels = 1 + +if model_name == "unet": + model = Model(resolution=args.image_size, + in_channels=1, + out_ch=1, + ch=128, + ch_mult=(1, 2, 2, 2), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.1).cuda() + +elif model_name == "twounet": + + model = TwoBranchNewModel(resolution=args.image_size, + in_channels=1, + out_ch=1, + ch=128, + ch_mult=(1, 2, 2, 2), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.1).cuda() + +elif model_name == "twobranch": + downsample = [4, 4, 4] + disc_channels = 64 + disc_layers = 3 + discriminator_iter_start = 10000 + disc_loss_type = "hinge" + image_gan_weight = 1.0 + video_gan_weight = 1.0 + l1_weight = 4.0 + gan_feat_weight = 4.0 + perceptual_weight = 4.0 + i3d_feat = False + restart_thres = 1.0 + no_random_restart = False + norm_type = "group" + padding_type = "replicate" + num_groups = 32 + + base_num_every_group = 2 + num_features = 64 + act = "PReLU" + num_channels = 1 + + model = TwoBranchModel( + image_channels, + disc_channels, disc_layers, disc_loss_type, + gan_feat_weight, image_gan_weight, + discriminator_iter_start, + perceptual_weight, l1_weight, + num_features, act, base_num_every_group, num_channels + ).cuda() + + +diffusion = GaussianDiffusion( + diffusion_type, + model, + image_size=args.image_size, # Used to be 32 + channels=image_channels, + device_of_kernel='cuda', + timesteps=args.time_steps, + loss_type=args.loss_type, #$'l1', + kernel_std=args.kernel_std, + fade_routine=args.fade_routine, + sampling_routine=args.sampling_routine, + discrete=args.discrete +).cuda() + + +trainer = Trainer( + diffusion, + img_path, + image_size = 32, + train_batch_size = 32, + train_lr = 2e-5, + train_num_steps = 700000, # total training steps + gradient_accumulate_every = 2, # gradient accumulation steps + ema_decay = 0.995, # exponential moving average decay + fp16 = False, # turn on mixed precision training with apex + results_folder = args.save_folder, + load_path = args.load_path +) + + + + +if args.test_type == 'train_data': + trainer.test_from_data('train', s_times=args.sample_steps) + +elif args.test_type == 'test_data': + trainer.test_from_data('test', s_times=args.sample_steps) + +elif args.test_type == 'mixup_train_data': + trainer.test_with_mixup('train') + +elif args.test_type == 'mixup_test_data': + trainer.test_with_mixup('test') + +elif args.test_type == 'test_random': + trainer.test_from_random('random') + +elif args.test_type == 'test_fid_distance_decrease_from_manifold': + trainer.fid_distance_decrease_from_manifold(calculate_fid_given_samples, start=0, end=None) + +elif args.test_type == 'test_paper_invert_section_images': + trainer.paper_invert_section_images() + +elif args.test_type == 'test_paper_showing_diffusion_images_diff': + trainer.paper_showing_diffusion_images() diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/visualization/eroor_map_.py b/MRI_recon/new_code/Frequency-Diffusion-main/visualization/eroor_map_.py new file mode 100644 index 0000000000000000000000000000000000000000..91fde841b6020435360c0e773e88934c360d7b72 --- 
/dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/visualization/eroor_map_.py @@ -0,0 +1,49 @@ +""" +读取图像的ground truth和每个round重建结果, 并且绘制error map. +""" +import os +import numpy as np +from PIL import Image +from skimage import io +from matplotlib import pyplot as plt + + +def normalize_image(image): + # return (image - image.min())/(image.max() - image.min()) + image = image[40:200, 55:215] + # image = image[80:160, 95:175] + print("image shape:", image.shape) + return image/255.0 + + +def viz_diff_img(image, test_outputdir, image_name): + print("image range:", image.max(), image.min()) + plt.imshow(image, cmap='jet') + plt.savefig(os.path.join(test_outputdir, f'{image_name}'), + bbox_inches='tight') + + +root_dir = "/data/xiaohan/BRATS_dataset/image_100patients_unimodal/" +image_name = "BraTS20_Training_042_60_t1" + +dst_dir = "./recon_image_visualization" + + +img_gt = normalize_image(np.array(Image.open(root_dir + image_name + ".png"))) +img_in = normalize_image(np.array(Image.open(root_dir + image_name + "_10dB.png"))) + +img_round1 = normalize_image(np.array(Image.open(root_dir + image_name + "_10dB_krecon_round1.png"))) + +print(img_gt.max(), img_gt.min()) +print(img_in.max(), img_in.min()) +print(img_round1.max(), img_round1.min()) + +io.imsave(os.path.join(dst_dir, image_name + ".png"), img_gt) +io.imsave(os.path.join(dst_dir, image_name + "_10dB.png"), img_in) +io.imsave(os.path.join(dst_dir, image_name + "_10dB_round1.png"), img_round1) + +viz_diff_img(np.abs(img_gt - img_in)*255, dst_dir, image_name + "_input_error.png") +viz_diff_img(np.abs(img_gt - img_round1)*255, dst_dir, image_name + "_round1_error.png") + +print("input error:", np.mean(np.abs(img_gt - img_in))) +print("round1 error:", np.mean(np.abs(img_gt - img_round1))) \ No newline at end of file diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/visualization/error_map.py b/MRI_recon/new_code/Frequency-Diffusion-main/visualization/error_map.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa3f468d86ff43154659f1d87e768896e0f0dd3 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/visualization/error_map.py @@ -0,0 +1,54 @@ +""" +读取图像的ground truth和每个round重建结果, 并且绘制error map. 
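(English: load the ground-truth image and each round's reconstruction result, and plot the error map.)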
+""" +import os +import numpy as np +from PIL import Image +from skimage import io +from matplotlib import pyplot as plt + + +def normalize_image(image): + # return (image - image.min())/(image.max() - image.min()) + image = image[40:200, 55:215] + # image = image[80:160, 95:175] + print("image shape:", image.shape) + return image/255.0 + # return (image - image.min())/(image.max() - image.min()) + + +def viz_diff_img(image, test_outputdir, image_name): + print("image range:", image.max(), image.min()) + plt.axis('off') + # plt.imshow(image, cmap='jet',vmin=0, vmax=50) + plt.imshow(image, cmap='jet',vmin=0, vmax=30) + # plt.colorbar() + plt.savefig(os.path.join(test_outputdir, f'{image_name}'), bbox_inches='tight',pad_inches = 0) + +# baseline = 'UNet_4X' +# baseline_list = ['DCAMSR_4X', 'MCCA_4X', 'MINet_4X', 'MTrans_4X', 'swinir_4X_'] +baseline_list = ['DCAMSR_8X', 'MCCA_8X', 'MINet_8X', 'MTrans_8X', 'swinir_8X_'] +# baseline_list = ['swinir_8X_'] +baseline_list = ['our'] +for baseline in baseline_list: + # root_dir = f"/data/qic99/recon_code/recon_2M/BRATS_baseline/model/{baseline}/result_case/" + root_dir = '/data/qic99/recon_code/recon_2M/BRATS_freq_multi_fusion_2_2_4/model/unet_wo_kspace_4X_lr1e-4/result_case/' + # root_dir = '/data/qic99/recon_code/recon_2M/BRATS_freq_multi_fusion_2_2_4/model/unet_wo_kspace_8X_lr1e-4/result_case/' + image_name = "301_t2" + + dst_dir = "./error_map_8X" + # dst_dir = "./error_map_4X" + os.makedirs(dst_dir, exist_ok=True) + img_gt = normalize_image(np.array(Image.open(root_dir + image_name + ".png"))) + img_in = normalize_image(np.array(Image.open(root_dir + image_name + "_out.png"))) + img_lq = normalize_image(np.array(Image.open(root_dir + image_name + "_in.png"))) + + print(img_gt.max(), img_gt.min()) + print(img_in.max(), img_in.min()) + io.imsave(os.path.join(dst_dir, image_name + "_lq.png"), (img_lq*255).astype(np.uint8)) + io.imsave(os.path.join(dst_dir, image_name + ".png"), (img_gt*255).astype(np.uint8)) + io.imsave(os.path.join(dst_dir, baseline+'_'+image_name + "_out.png"), (img_in*255).astype(np.uint8)) + viz_diff_img(np.abs(img_gt - img_in)*255, dst_dir, baseline+'_'+image_name + "_error_map.png") + + print("input error:", np.mean(np.abs(img_gt - img_in))) + # break diff --git a/MRI_recon/new_code/Frequency-Diffusion-main/visualization/error_map2.py b/MRI_recon/new_code/Frequency-Diffusion-main/visualization/error_map2.py new file mode 100644 index 0000000000000000000000000000000000000000..62ab1dfff89d3d8f483ec948cb069ee44d0c70c0 --- /dev/null +++ b/MRI_recon/new_code/Frequency-Diffusion-main/visualization/error_map2.py @@ -0,0 +1,53 @@ +""" +读取图像的ground truth和每个round重建结果, 并且绘制error map. 
+""" +import os +import numpy as np +from PIL import Image +from skimage import io +from matplotlib import pyplot as plt + + +def normalize_image(image): + # return (image - image.min())/(image.max() - image.min()) + image = image[40:200, 55:215] + # image = image[80:160, 95:175] + print("image shape:", image.shape) + return image/255.0 + # return (image - image.min())/(image.max() - image.min()) + + +def viz_diff_img(image, test_outputdir, image_name): + print("image range:", image.max(), image.min()) + plt.axis('off') + # plt.imshow(image, cmap='jet',vmin=0, vmax=50) + plt.imshow(image, cmap='jet',vmin=0, vmax=80) + plt.savefig(os.path.join(test_outputdir, f'{image_name}'), bbox_inches='tight',pad_inches = 0) + +# baseline = 'UNet_4X' +baseline_list = ['DCAMSR_4x', 'MCCA_4x', 'MINet_4x', 'MTrans_4x', 'swinIR_4x'] +# baseline_list = ['DCAMSR_8X', 'MCCA_8X', 'MINet_8X', 'MTrans_8X', 'swinir_8X_'] +# baseline_list = ['swinir_8X_'] +baseline_list = ['our'] +for baseline in baseline_list: + # root_dir = f"/data/qic99/recon_code/recon_2M/fastMRI_baseline/model/{baseline}/result_case/" + root_dir = '/data/qic99/recon_code/recon_2M/BRATS_freq_multi_fusion_2_2_4/model/our_fastmri_4x/result_case/' + # root_dir = '/data/qic99/recon_code/recon_2M/BRATS_freq_multi_fusion_2_2_4/model/our_fastmri_8x/result_case/' + image_name = "file1001059_11" + + # dst_dir = "./fastMRI_error_map_8X" + dst_dir = "./fastMRI_error_map_4X" + os.makedirs(dst_dir, exist_ok=True) + img_gt = normalize_image(np.array(Image.open(root_dir + image_name + ".png"))) + img_in = normalize_image(np.array(Image.open(root_dir + image_name + "_out.png"))) + img_lq = normalize_image(np.array(Image.open(root_dir + image_name + "_in.png"))) + # breakpoint() + print(img_gt.max(), img_gt.min()) + print(img_in.max(), img_in.min()) + io.imsave(os.path.join(dst_dir, image_name + "_lq.png"), (img_lq).astype(np.uint8)) + io.imsave(os.path.join(dst_dir, image_name + ".png"), (img_gt).astype(np.uint8)) + io.imsave(os.path.join(dst_dir, baseline+'_'+image_name + "_out.png"), (img_in).astype(np.uint8)) + viz_diff_img(np.abs(img_gt - img_in), dst_dir, baseline+'_'+image_name + "_error_map.png") + + print("input error:", np.mean(np.abs(img_gt - img_in))) + # break